From e2495f43b04c0f4dafb53823f32afee5dc8dd000 Mon Sep 17 00:00:00 2001
From: jiarui_li
Date: Thu, 16 Oct 2025 18:37:55 +0800
Subject: [PATCH] revert 6e0713048ab52c17c05fb42ffb7949f0e32591a2

revert update attn mask
---
 models.py | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/models.py b/models.py
index 17e080f..32021b2 100644
--- a/models.py
+++ b/models.py
@@ -159,27 +159,20 @@ class TimeAwareGPT2(nn.Module):
         # b) Padding mask: Do not attend to positions where the event token is 0.
 
         # a) Time-based causal mask
-        # t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
-        # t_j = time_seq.unsqueeze(1)   # (B, 1, L)
-        # time_mask = (t_j <= t_i)
+        t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
+        t_j = time_seq.unsqueeze(1)   # (B, 1, L)
+        time_mask = (t_j <= t_i)
 
         # b) Padding mask (prevents attending to key positions that are padding)
-        # padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)
+        padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)
 
         # Combine the masks. A position (j) can be attended to by a query (i) only if
         # it's in the past (time_mask) AND it's not a padding token (padding_mask).
-        # combined_mask = time_mask & padding_mask
-
-        attn_mask = (x>0).view(x.size(0), 1, 1, x.size(1)) * (x>0).view(x.size(0),1,x.size(1),1) # Do not attend to padded positions
-        attn_mask *= torch.tril(torch.ones(x.size(1),x.size(1)))[None,None,:,:] > 0 #self.transformer.h[0].attn.bias[:,:,:x.size(1),:x.size(1)] > 0
-        attn_mask *= ((age.view(x.size(0),1,1,x.size(1)) != targets_age.view(x.size(0),1,x.size(1),1))) # Mask co-occuring tokens
-        attn_mask += (attn_mask.sum(-1, keepdim=True)==0) * torch.diag(torch.ones(x.size(1))) > 0
-        attn_mask = attn_mask + (x==0).view(x.size(0), 1, 1, x.size(1)) * torch.diag(torch.ones(x.size(1))) > 0 # Except for padding
-        attn_mask *= torch.tril(torch.ones(x.size(1),x.size(1)))[None,None,:,:] > 0 #self.transformer.h[0].attn.bias[:,:,:x.size(1),:x.size(1)] > 0
+        combined_mask = time_mask & padding_mask
 
         # 6. Pass through transformer blocks
         for block in self.blocks:
-            x = block(x, custom_mask=attn_mask)
+            x = block(x, custom_mask=combined_mask)
 
         # 7. Final layer norm and projection to vocab size
         x = self.ln_f(x)
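
Note (not part of the patch): below is a minimal, self-contained sketch of the mask construction this revert restores, assuming the shapes and names used in the hunk above: time_seq and event_seq are (B, L) tensors and token id 0 is padding. The toy values are illustrative only.

import torch

B, L = 2, 4
# Event tokens; 0 is assumed to be the padding id, as in the patch comments.
event_seq = torch.tensor([[5, 3, 7, 0],
                          [2, 2, 6, 0]])
# Per-token timestamps (e.g. age at each event).
time_seq = torch.tensor([[1.0, 2.0, 4.0, 0.0],
                         [3.0, 3.0, 5.0, 0.0]])

# a) Time-based causal mask: query i may attend to key j only if t_j <= t_i.
t_i = time_seq.unsqueeze(-1)   # (B, L, 1)
t_j = time_seq.unsqueeze(1)    # (B, 1, L)
time_mask = (t_j <= t_i)       # (B, L, L), bool

# b) Padding mask: never attend to key positions whose event token is 0.
padding_mask = (event_seq != 0).unsqueeze(1)   # (B, 1, L)

# Combined: a key is attendable iff it is not in the future AND not padding.
combined_mask = time_mask & padding_mask       # broadcasts to (B, L, L)

print(combined_mask[0].int())

Two behavioural differences from the reverted attn_mask code are worth noting: the restored time_mask keeps tokens that share the same timestamp visible to each other (the removed code masked co-occurring tokens), and a fully padded query row now yields an all-False mask row rather than having its diagonal re-enabled via torch.diag; how that case is handled depends on the attention blocks receiving custom_mask.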