Enhance DataLoader configuration and improve tensor transfer efficiency in Trainer class
This commit is contained in:
25
model.py
25
model.py
@@ -263,13 +263,24 @@ def _build_time_padding_mask(
|
||||
event_seq: torch.Tensor,
|
||||
time_seq: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_i = time_seq.unsqueeze(-1)
|
||||
t_j = time_seq.unsqueeze(1)
|
||||
time_mask = (t_j <= t_i) # allow attending only to past or current
|
||||
key_is_valid = (event_seq != 0) # disallow padded positions
|
||||
allowed = time_mask & key_is_valid.unsqueeze(1)
|
||||
attn_mask = ~allowed # True means mask for scaled_dot_product_attention
|
||||
return attn_mask.unsqueeze(1) # (B, 1, L, L)
|
||||
B, L = event_seq.shape
|
||||
device = event_seq.device
|
||||
key_is_valid = (event_seq != 0)
|
||||
|
||||
cache = getattr(_build_time_padding_mask, "_cache", None)
|
||||
if cache is None:
|
||||
cache = {}
|
||||
setattr(_build_time_padding_mask, "_cache", cache)
|
||||
cache_key = (str(device), L)
|
||||
causal = cache.get(cache_key)
|
||||
if causal is None:
|
||||
causal = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
|
||||
cache[cache_key] = causal
|
||||
|
||||
causal = causal.unsqueeze(0).unsqueeze(0) # (1,1,L,L)
|
||||
key_pad = (~key_is_valid).unsqueeze(1).unsqueeze(2) # (B,1,1,L)
|
||||
attn_mask = causal | key_pad # (B,1,L,L)
|
||||
return attn_mask
|
||||
|
||||
|
||||
class DelphiFork(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user