models.py
@@ -159,27 +159,20 @@ class TimeAwareGPT2(nn.Module):
         # b) Padding mask: Do not attend to positions where the event token is 0.
 
         # a) Time-based causal mask
-        # t_i = time_seq.unsqueeze(-1) # (B, L, 1)
-        # t_j = time_seq.unsqueeze(1) # (B, 1, L)
-        # time_mask = (t_j <= t_i)
+        t_i = time_seq.unsqueeze(-1) # (B, L, 1)
+        t_j = time_seq.unsqueeze(1) # (B, 1, L)
+        time_mask = (t_j <= t_i)
 
         # b) Padding mask (prevents attending to key positions that are padding)
-        # padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L)
+        padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L)
 
         # Combine the masks. A position (j) can be attended to by a query (i) only if
         # it's in the past (time_mask) AND it's not a padding token (padding_mask).
-        # combined_mask = time_mask & padding_mask
+        combined_mask = time_mask & padding_mask
 
-        attn_mask = (x>0).view(x.size(0), 1, 1, x.size(1)) * (x>0).view(x.size(0),1,x.size(1),1) # Do not attend to padded positions
-        attn_mask *= torch.tril(torch.ones(x.size(1),x.size(1)))[None,None,:,:] > 0 #self.transformer.h[0].attn.bias[:,:,:x.size(1),:x.size(1)] > 0
-        attn_mask *= ((age.view(x.size(0),1,1,x.size(1)) != targets_age.view(x.size(0),1,x.size(1),1))) # Mask co-occuring tokens
-        attn_mask += (attn_mask.sum(-1, keepdim=True)==0) * torch.diag(torch.ones(x.size(1))) > 0
-        attn_mask = attn_mask + (x==0).view(x.size(0), 1, 1, x.size(1)) * torch.diag(torch.ones(x.size(1))) > 0 # Except for padding
-        attn_mask *= torch.tril(torch.ones(x.size(1),x.size(1)))[None,None,:,:] > 0 #self.transformer.h[0].attn.bias[:,:,:x.size(1),:x.size(1)] > 0
-
         # 6. Pass through transformer blocks
         for block in self.blocks:
-            x = block(x, custom_mask=attn_mask)
+            x = block(x, custom_mask=combined_mask)
 
         # 7. Final layer norm and projection to vocab size
         x = self.ln_f(x)
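
For reference, a minimal, self-contained sketch of the mask logic introduced by this hunk, assuming `event_seq` holds token ids with 0 reserved for padding and `time_seq` holds the per-token timestamps the model orders events by. The helper name `build_attention_mask`, the toy tensors, and the use of `torch.nn.functional.scaled_dot_product_attention` are illustrative only and not part of models.py; the diagonal fix-up for fully masked rows mirrors the idea in the removed `attn_mask` code.

# Sketch only: mask semantics follow the diff above (True = may attend).
import torch
import torch.nn.functional as F

def build_attention_mask(event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor:
    """Boolean mask of shape (B, L, L): entry (i, j) is True if query i may attend to key j."""
    t_i = time_seq.unsqueeze(-1)                   # (B, L, 1) query times
    t_j = time_seq.unsqueeze(1)                    # (B, 1, L) key times
    time_mask = t_j <= t_i                         # only events at or before the query's time
    padding_mask = (event_seq != 0).unsqueeze(1)   # (B, 1, L): never attend to padded keys
    mask = time_mask & padding_mask                # broadcasts to (B, L, L)
    # Padded query rows end up all-False; let them attend to themselves so softmax
    # stays finite (same idea as the diagonal fix-up in the removed attn_mask code).
    eye = torch.eye(event_seq.size(1), dtype=torch.bool, device=event_seq.device)
    mask = mask | (~mask.any(dim=-1, keepdim=True) & eye)
    return mask

# Toy usage with PyTorch's built-in attention; a (B, 1, L, L) bool mask broadcasts over heads.
B, L, H, D = 2, 5, 4, 16
event_seq = torch.tensor([[3, 7, 7, 2, 0], [5, 1, 4, 0, 0]])
time_seq = torch.tensor([[1.0, 2.5, 2.5, 4.0, 0.0], [0.5, 1.0, 3.0, 0.0, 0.0]])
mask = build_attention_mask(event_seq, time_seq)                             # (B, L, L)
q = k = v = torch.randn(B, H, L, D)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1))   # (B, H, L, D)

Whether `block(x, custom_mask=...)` expects a boolean keep-mask like this or an additive float mask depends on the Block implementation in models.py; the sketch assumes the boolean convention used by scaled_dot_product_attention.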