feat: Update model and training parameters

In `models.py`:
- Change temporal attention mask to be strictly causal (`<` instead of `<=`).
- Add a self-attention fallback for non-padding tokens that cannot attend to any earlier token (e.g., the first token in a sequence) to prevent NaNs.

In `train.py`:
- Update hyperparameters:
  - `block_length`: 24 -> 48
  - `n_embd`: 256 -> 120
  - `n_layer`: 8 -> 12
  - `n_head`: 8 -> 12
commit cb7575a229 (parent e2495f43b0)
Date: 2025-10-16 18:50:15 +08:00
2 changed files with 13 additions and 6 deletions

models.py

@@ -155,13 +155,13 @@ class TimeAwareGPT2(nn.Module):
         # 5. Generate attention mask
         # The attention mask combines two conditions:
-        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] <= time_seq[i].
+        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i].
         # b) Padding mask: Do not attend to positions where the event token is 0.

         # a) Time-based causal mask
         t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
         t_j = time_seq.unsqueeze(1)   # (B, 1, L)
-        time_mask = (t_j <= t_i)
+        time_mask = (t_j < t_i)

         # b) Padding mask (prevents attending to key positions that are padding)
         padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)
@@ -170,6 +170,13 @@ class TimeAwareGPT2(nn.Module):
         # it's in the past (time_mask) AND it's not a padding token (padding_mask).
         combined_mask = time_mask & padding_mask

+        # Forcibly allow a non-padding token to attend to itself if it cannot attend to any other token.
+        # This prevents NaN issues in the attention mechanism for the first token in a sequence.
+        is_row_all_zero = ~combined_mask.any(dim=-1)
+        is_not_padding = (event_seq != 0)
+        force_self_attention = is_row_all_zero & is_not_padding
+        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
+
         # 6. Pass through transformer blocks
         for block in self.blocks:
             x = block(x, custom_mask=combined_mask)
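
For context, here is a minimal standalone sketch of the mask logic from the hunk above, run on a toy batch (the `event_seq` and `time_seq` values are invented for illustration). It shows the effect of the strict `<` comparison (tokens sharing a timestamp no longer see each other) and why the self-attention fallback matters: without it, the earliest token's mask row is all `False`, and a softmax over a fully masked row yields NaN.

```python
import torch

# Toy inputs: batch of 1, sequence length 4; the last position is padding (event id 0).
event_seq = torch.tensor([[5, 7, 7, 0]])         # (B, L)
time_seq = torch.tensor([[10., 20., 20., 0.]])   # (B, L)

# a) Strictly causal time mask: position i may attend to j only if time_seq[j] < time_seq[i].
t_i = time_seq.unsqueeze(-1)   # (B, L, 1)
t_j = time_seq.unsqueeze(1)    # (B, 1, L)
time_mask = (t_j < t_i)

# b) Padding mask: never attend to key positions whose event token is 0.
padding_mask = (event_seq != 0).unsqueeze(1)     # (B, 1, L)

combined_mask = time_mask & padding_mask         # (B, L, L)

# With the strict comparison, row 0 (the earliest token) is all False; a softmax over a
# fully masked row would produce NaN. Let such non-padding tokens attend to themselves.
is_row_all_zero = ~combined_mask.any(dim=-1)     # (B, L)
is_not_padding = (event_seq != 0)                # (B, L)
force_self_attention = is_row_all_zero & is_not_padding
combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True

print(combined_mask[0].int())
# tensor([[1, 0, 0, 0],   <- earliest token now attends to itself only
#         [1, 0, 0, 0],   <- same-time tokens (rows 1 and 2) no longer see each other
#         [1, 0, 0, 0],
#         [0, 0, 0, 0]])  <- padding row stays empty, as in the patch
```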

train.py

@@ -15,12 +15,12 @@ class TrainConfig:
     # Data parameters
     train_data_path = 'ukb_real_train.bin'
     val_data_path = 'ukb_real_val.bin'
-    block_length = 24  # Sequence length
+    block_length = 48  # Sequence length

     # Model parameters
-    n_embd = 256
-    n_layer = 8
-    n_head = 8
+    n_embd = 120
+    n_layer = 12
+    n_head = 12
     pdrop = 0.1
     token_pdrop = 0.1
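
As a sanity check on the new settings, here is a back-of-the-envelope sketch; it assumes the blocks follow the standard GPT-2 layout (~4·n_embd² attention plus ~8·n_embd² MLP parameters per layer) and ignores embeddings, biases, and layer norms, which the actual model may deviate from. The change trades width for depth, roughly thirds the block parameter count, doubles the training sequence length from 24 to 48, and keeps `n_embd` divisible by `n_head`.

```python
# Approximate transformer-block parameter count for a GPT-2-style model
# (attention ~4*n_embd^2 + MLP ~8*n_embd^2 per layer; embeddings excluded).
def approx_block_params(n_layer: int, n_embd: int) -> int:
    return 12 * n_layer * n_embd * n_embd

old_params = approx_block_params(n_layer=8, n_embd=256)
new_params = approx_block_params(n_layer=12, n_embd=120)
print(f"old ~ {old_params / 1e6:.1f}M, new ~ {new_params / 1e6:.1f}M")  # old ~ 6.3M, new ~ 2.1M

# The embedding width must split evenly across attention heads.
assert 256 % 8 == 0    # old: 32-dimensional heads
assert 120 % 12 == 0   # new: 10-dimensional heads
```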