Enhance DataLoader configuration and improve tensor transfer efficiency in Trainer class

This commit is contained in:
2026-01-08 13:20:32 +08:00
parent 5382f9f159
commit 01a96d37ea
4 changed files with 81 additions and 30 deletions

View File

@@ -263,13 +263,24 @@ def _build_time_padding_mask(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
) -> torch.Tensor:
# Build a boolean attention mask of shape (B, 1, L, L) in which True marks a
# (query, key) pair that is blocked, combining an ordering constraint with
# masking of padded keys (event id 0).
#
# NOTE(review): this span looks like a diff hunk rendered without +/- markers.
# The statements up to the first `return` appear to be the previous
# implementation and the statements after it the replacement - confirm
# against the repository. If the text were executed exactly as shown, the
# first `return` would make everything after it unreachable.
#
# --- Variant 1: ordering derived from timestamps --------------------------
# assumes time_seq is (B, L) like event_seq - TODO confirm
# time_mask[b, i, j] is True when key j is not later in time than query i,
# so keys sharing the query's timestamp stay visible regardless of position.
t_i = time_seq.unsqueeze(-1)
t_j = time_seq.unsqueeze(1)
time_mask = (t_j <= t_i) # allow attending only to past or current
key_is_valid = (event_seq != 0) # disallow padded positions
allowed = time_mask & key_is_valid.unsqueeze(1)
# NOTE(review): torch.nn.functional.scaled_dot_product_attention documents a
# boolean attn_mask as True = "takes part in attention", the opposite of the
# convention stated on the next line - verify how the caller consumes this.
attn_mask = ~allowed # True means mask for scaled_dot_product_attention
return attn_mask.unsqueeze(1) # (B, 1, L, L)
# --- Variant 2: ordering derived from position, with a cached triangle ----
# NOTE(review): nothing below reads time_seq. Ordering is by sequence index
# only, so the result matches variant 1 only when timestamps strictly
# increase along the sequence; tied or unsorted timestamps give a different
# mask. Confirm that this change in behaviour is intended.
B, L = event_seq.shape
device = event_seq.device
key_is_valid = (event_seq != 0)
# The cache lives as an attribute on the function object and is created on
# first use. No eviction is visible here, so one (L, L) bool tensor is kept
# per distinct (device, L) pair for as long as the function object lives.
cache = getattr(_build_time_padding_mask, "_cache", None)
if cache is None:
cache = {}
setattr(_build_time_padding_mask, "_cache", cache)
cache_key = (str(device), L)
causal = cache.get(cache_key)
if causal is None:
# triu(1) sets True strictly above the diagonal: key index > query index.
causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
cache[cache_key] = causal
# The unsqueezed view is rebound locally; the cached entry stays (L, L).
causal = causal.unsqueeze(0).unsqueeze(0) # (1,1,L,L)
key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2) # (B,1,1,L)
# Broadcasting the OR yields True where the key is in the future OR padded.
attn_mask = causal | key_pad # (B,1,L,L)
return attn_mask
class DelphiFork(nn.Module):