diff --git a/models.py b/models.py
index 32021b2..3309fc0 100644
--- a/models.py
+++ b/models.py
@@ -155,13 +155,13 @@ class TimeAwareGPT2(nn.Module):
 
         # 5. Generate attention mask
         # The attention mask combines two conditions:
-        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] <= time_seq[i].
+        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i].
         # b) Padding mask: Do not attend to positions where the event token is 0.
 
         # a) Time-based causal mask
         t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
         t_j = time_seq.unsqueeze(1)   # (B, 1, L)
-        time_mask = (t_j <= t_i)
+        time_mask = (t_j < t_i)
 
         # b) Padding mask (prevents attending to key positions that are padding)
         padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)
@@ -170,6 +170,13 @@ class TimeAwareGPT2(nn.Module):
         # it's in the past (time_mask) AND it's not a padding token (padding_mask).
         combined_mask = time_mask & padding_mask
 
+        # Forcibly allow a non-padding token to attend to itself if it cannot attend to any other token.
+        # This prevents NaN issues in the attention mechanism for the first token in a sequence.
+        is_row_all_zero = ~combined_mask.any(dim=-1)
+        is_not_padding = (event_seq != 0)
+        force_self_attention = is_row_all_zero & is_not_padding
+        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
+
         # 6. Pass through transformer blocks
         for block in self.blocks:
             x = block(x, custom_mask=combined_mask)
diff --git a/train.py b/train.py
index 808be72..cebc6cf 100644
--- a/train.py
+++ b/train.py
@@ -15,12 +15,12 @@ class TrainConfig:
     # Data parameters
     train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
-    block_length = 24  # Sequence length
+    block_length = 48  # Sequence length
 
     # Model parameters
-    n_embd = 256
-    n_layer = 8
-    n_head = 8
+    n_embd = 120
+    n_layer = 12
+    n_head = 12
     pdrop = 0.1
     token_pdrop = 0.1
 
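
A minimal, self-contained sketch (not part of the patch) of the masking logic touched by the models.py hunks. The toy event_seq/time_seq batch is invented for illustration, and the variable names mirror the diff. With the strict "<" comparison, tokens at the earliest timestamp have no valid keys, so the forced self-attention keeps their softmax rows from producing NaNs.

    import torch

    # One sequence of 4 events; the last position is padding (token 0).
    event_seq = torch.tensor([[3, 5, 7, 0]])             # (B, L)
    time_seq  = torch.tensor([[10.0, 10.0, 20.0, 0.0]])  # (B, L)

    t_i = time_seq.unsqueeze(-1)                  # (B, L, 1)
    t_j = time_seq.unsqueeze(1)                   # (B, 1, L)
    time_mask = (t_j < t_i)                       # attend only to strictly earlier events
    padding_mask = (event_seq != 0).unsqueeze(1)  # (B, 1, L)
    combined_mask = time_mask & padding_mask      # (B, L, L)

    # Positions 0 and 1 see no strictly earlier, non-padding event, so their
    # mask rows are all False; force self-attention there to avoid NaN rows.
    is_row_all_zero = ~combined_mask.any(dim=-1)
    is_not_padding = (event_seq != 0)
    force_self_attention = is_row_all_zero & is_not_padding
    combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True

    print(combined_mask[0].long())
    # tensor([[1, 0, 0, 0],
    #         [0, 1, 0, 0],
    #         [1, 1, 0, 0],
    #         [0, 0, 0, 0]])
    # Rows 0 and 1 now attend only to themselves; the padding row stays fully masked.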