update attn mask

commit 6e0713048a
parent eec406d79f
2025-10-16 18:29:48 +08:00


@@ -159,20 +159,27 @@ class TimeAwareGPT2(nn.Module):
         # b) Padding mask: Do not attend to positions where the event token is 0.
 
         # a) Time-based causal mask
-        t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
-        t_j = time_seq.unsqueeze(1)   # (B, 1, L)
-        time_mask = (t_j <= t_i)
+        # t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
+        # t_j = time_seq.unsqueeze(1)   # (B, 1, L)
+        # time_mask = (t_j <= t_i)
 
         # b) Padding mask (prevents attending to key positions that are padding)
-        padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)
+        # padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)
 
         # Combine the masks. A position (j) can be attended to by a query (i) only if
         # it's in the past (time_mask) AND it's not a padding token (padding_mask).
-        combined_mask = time_mask & padding_mask
+        # combined_mask = time_mask & padding_mask
+
+        attn_mask = (x>0).view(x.size(0), 1, 1, x.size(1)) * (x>0).view(x.size(0),1,x.size(1),1) # Do not attend to padded positions
+        attn_mask *= torch.tril(torch.ones(x.size(1),x.size(1)))[None,None,:,:] > 0 #self.transformer.h[0].attn.bias[:,:,:x.size(1),:x.size(1)] > 0
+        attn_mask *= ((age.view(x.size(0),1,1,x.size(1)) != targets_age.view(x.size(0),1,x.size(1),1))) # Mask co-occuring tokens
+        attn_mask += (attn_mask.sum(-1, keepdim=True)==0) * torch.diag(torch.ones(x.size(1))) > 0
+        attn_mask = attn_mask + (x==0).view(x.size(0), 1, 1, x.size(1)) * torch.diag(torch.ones(x.size(1))) > 0 # Except for padding
+        attn_mask *= torch.tril(torch.ones(x.size(1),x.size(1)))[None,None,:,:] > 0 #self.transformer.h[0].attn.bias[:,:,:x.size(1),:x.size(1)] > 0
 
         # 6. Pass through transformer blocks
         for block in self.blocks:
-            x = block(x, custom_mask=combined_mask)
+            x = block(x, custom_mask=attn_mask)
 
         # 7. Final layer norm and projection to vocab size
         x = self.ln_f(x)
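
For readers tracing the new logic, below is a stand-alone sketch of the same mask construction, rewritten with explicit boolean operators for readability. The helper name build_attn_mask is illustrative and not part of the file; it assumes x is the (B, L) event-token tensor (0 = padding) and that age / targets_age carry the per-position input and target ages used by the co-occurrence check.

import torch

def build_attn_mask(x, age, targets_age):
    """Illustrative re-implementation of the attn_mask built above.

    x:           (B, L) event tokens, 0 = padding (assumed)
    age:         (B, L) age of each input token (assumed)
    targets_age: (B, L) age of each target token (assumed)
    Returns a (B, 1, L, L) boolean mask; True = "query i may attend to key j".
    """
    B, L = x.shape
    not_pad = x > 0

    # Queries and keys must both be real (non-padding) tokens.
    mask = not_pad.view(B, 1, 1, L) & not_pad.view(B, 1, L, 1)

    # Standard causal constraint (lower-triangular).
    causal = torch.tril(torch.ones(L, L, device=x.device)).bool()[None, None]
    mask = mask & causal

    # Hide co-occurring tokens: keys recorded at the query's target age are masked out.
    mask = mask & (age.view(B, 1, 1, L) != targets_age.view(B, 1, L, 1))

    # Fix-ups so no softmax row is entirely masked: queries left with no visible key,
    # and padded positions, fall back to attending to themselves (the diagonal).
    eye = torch.eye(L, device=x.device).bool()[None, None]
    mask = mask | ((mask.sum(-1, keepdim=True) == 0) & eye)
    mask = mask | ((x == 0).view(B, 1, 1, L) & eye)

    # Re-apply the causal constraint after the diagonal fix-ups.
    return mask & causal

As in the diff, the mask is boolean with True meaning "may attend" (matching how custom_mask was used with the previous combined_mask), the diagonal fallbacks keep every attention row from being entirely masked, and the final causal re-application ensures the fix-ups never let a query see a future position.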