feat: Update model and training parameters
In `models.py`:
- Change the temporal attention mask to be strictly causal (`<` instead of `<=`); see the toy sketch below.
- Add self-attention for the first token in a sequence to prevent NaNs.

In `train.py`:
- Update hyperparameters:
  - `block_length`: 24 -> 48
  - `n_embd`: 256 -> 120
  - `n_layer`: 8 -> 12
  - `n_head`: 8 -> 12
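To make the mask change concrete, here is a small illustrative sketch (PyTorch; the `time_seq`/`event_seq` values are made up, the variable names follow the diff below). With the strict `<`, events sharing a timestamp no longer attend to each other, and the earliest event in a sequence ends up with an all-False mask row, which is exactly the situation the second change guards against:

```python
import torch

# Toy inputs (names follow the diff; values are made up for illustration).
time_seq  = torch.tensor([[1.0, 3.0, 3.0, 7.0]])  # (B=1, L=4) event times
event_seq = torch.tensor([[5, 9, 9, 2]])          # non-zero tokens, no padding here

t_i = time_seq.unsqueeze(-1)  # (B, L, 1) query times
t_j = time_seq.unsqueeze(1)   # (B, 1, L) key times

old_mask = (t_j <= t_i)  # previous behaviour: same-time events attend to each other
new_mask = (t_j < t_i)   # strictly causal: same-time events are masked out

print(old_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
print(new_mask[0].int())
# tensor([[0, 0, 0, 0],   <- the first event now has no attendable key at all
#         [1, 0, 0, 0],
#         [1, 0, 0, 0],
#         [1, 1, 1, 0]], dtype=torch.int32)
```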
models.py
@@ -155,13 +155,13 @@ class TimeAwareGPT2(nn.Module):
         # 5. Generate attention mask
         # The attention mask combines two conditions:
-        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] <= time_seq[i].
+        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i].
         # b) Padding mask: Do not attend to positions where the event token is 0.
 
         # a) Time-based causal mask
         t_i = time_seq.unsqueeze(-1)  # (B, L, 1)
         t_j = time_seq.unsqueeze(1)   # (B, 1, L)
-        time_mask = (t_j <= t_i)
+        time_mask = (t_j < t_i)
 
         # b) Padding mask (prevents attending to key positions that are padding)
         padding_mask = (event_seq != 0).unsqueeze(1)  # Shape: (B, 1, L)
@@ -170,6 +170,13 @@ class TimeAwareGPT2(nn.Module):
         # it's in the past (time_mask) AND it's not a padding token (padding_mask).
         combined_mask = time_mask & padding_mask
 
+        # Forcibly allow a non-padding token to attend to itself if it cannot attend to any other token.
+        # This prevents NaN issues in the attention mechanism for the first token in a sequence.
+        is_row_all_zero = ~combined_mask.any(dim=-1)
+        is_not_padding = (event_seq != 0)
+        force_self_attention = is_row_all_zero & is_not_padding
+        combined_mask.diagonal(dim1=-2, dim2=-1)[force_self_attention] = True
+
         # 6. Pass through transformer blocks
         for block in self.blocks:
             x = block(x, custom_mask=combined_mask)
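Why an all-False mask row produces NaNs, sketched with a plain masked softmax (an assumed stand-in for the attention inside the blocks, not the repository's actual implementation), together with the diagonal fix the commit applies:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(1, 4, 4)                  # hypothetical raw attention scores (B, L, L)
mask = torch.zeros(1, 4, 4, dtype=torch.bool)  # row 0 all-False, like the first event above
mask[:, 1:, 0] = True                          # later events may attend to the first one

masked = scores.masked_fill(~mask, float('-inf'))
print(F.softmax(masked, dim=-1)[0, 0])         # tensor([nan, nan, nan, nan]): every key is -inf

# The commit's remedy: let a stranded (non-padding) row attend to itself.
is_row_all_zero = ~mask.any(dim=-1)            # (B, L) rows with no attendable key
mask.diagonal(dim1=-2, dim2=-1)[is_row_all_zero] = True

masked = scores.masked_fill(~mask, float('-inf'))
print(F.softmax(masked, dim=-1)[0, 0])         # tensor([1., 0., 0., 0.]): well-defined again
```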
train.py
@@ -15,12 +15,12 @@ class TrainConfig:
     # Data parameters
     train_data_path = 'ukb_real_train.bin'
     val_data_path = 'ukb_real_val.bin'
-    block_length = 24  # Sequence length
+    block_length = 48  # Sequence length
 
     # Model parameters
-    n_embd = 256
-    n_layer = 8
-    n_head = 8
+    n_embd = 120
+    n_layer = 12
+    n_head = 12
     pdrop = 0.1
     token_pdrop = 0.1
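One sanity note on the new sizes: the embedding width still divides evenly by the head count (256/8 = 32 before, 120/12 = 10 after), which is the usual requirement for multi-head attention. A minimal check, with illustrative variable names rather than anything taken from train.py:

```python
# Illustrative check only; the variable names are assumptions, not from train.py.
for n_embd, n_head in [(256, 8), (120, 12)]:
    assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
    print(f"n_embd={n_embd}, n_head={n_head} -> head_dim={n_embd // n_head}")
```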