Refactor: Improve attention mechanism and early stopping

- Refactor the self-attention mechanism in `models.py` to use `nn.MultiheadAttention` for better performance and clarity.
- Disable early stopping check during warmup epochs in `train.py` to improve training stability.
2025-10-16 15:57:27 +08:00
parent 8a757a8b1d
commit 4181ead03a
2 changed files with 80 additions and 51 deletions

models.py

@@ -2,57 +2,15 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from typing import Tuple
-import math
-
-
-class CausalSelfAttention(nn.Module):
-    """
-    A vanilla multi-head masked self-attention layer with a projection at the end.
-    """
-
-    def __init__(self, n_embd: int, n_head: int, pdrop: float):
-        super().__init__()
-        assert n_embd % n_head == 0
-        # key, query, value projections for all heads
-        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
-        # output projection
-        self.c_proj = nn.Linear(n_embd, n_embd)
-        # regularization
-        self.attn_dropout = nn.Dropout(pdrop)
-        self.resid_dropout = nn.Dropout(pdrop)
-        self.n_head = n_head
-        self.n_embd = n_embd
-
-    def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
-        B, L, D = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
-
-        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-        k = k.view(B, L, self.n_head, D // self.n_head).transpose(1, 2) # (B, nh, L, hs)
-        q = q.view(B, L, self.n_head, D // self.n_head).transpose(1, 2) # (B, nh, L, hs)
-        v = v.view(B, L, self.n_head, D // self.n_head).transpose(1, 2) # (B, nh, L, hs)
-
-        # causal self-attention; Self-attend: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L)
-        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        # Apply the time-based causal mask
-        att = att.masked_fill(custom_mask.unsqueeze(1) == 0, float('-inf'))
-        att = F.softmax(att, dim=-1)
-        att = self.attn_dropout(att)
-        y = att @ v # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs)
-        y = y.transpose(1, 2).contiguous().view(B, L, D) # re-assemble all head outputs side by side
-
-        # output projection
-        y = self.resid_dropout(self.c_proj(y))
-        return y
-
-
 class Block(nn.Module):
     """ an unassuming Transformer block """
 
     def __init__(self, n_embd: int, n_head: int, pdrop: float):
         super().__init__()
+        self.n_head = n_head
         self.ln_1 = nn.LayerNorm(n_embd)
-        self.attn = CausalSelfAttention(n_embd, n_head, pdrop)
+        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=pdrop, batch_first=True)
         self.ln_2 = nn.LayerNorm(n_embd)
         self.mlp = nn.ModuleDict(dict(
             c_fc = nn.Linear(n_embd, 4 * n_embd),
@@ -62,9 +20,16 @@ class Block(nn.Module):
         ))
         m = self.mlp
         self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward
+        self.resid_dropout = nn.Dropout(pdrop)
 
     def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.ln_1(x), custom_mask=custom_mask)
+        normed_x = self.ln_1(x)
+        attn_mask = ~custom_mask
+        attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0)
+        attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
+        x = x + self.resid_dropout(attn_output)
         x = x + self.mlpf(self.ln_2(x))
         return x
@@ -190,13 +155,13 @@ class TimeAwareGPT2(nn.Module):
         # 5. Generate attention mask
         # The attention mask combines two conditions:
-        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i].
+        # a) Time-based causality: A token i can attend to a token j only if time_seq[j] <= time_seq[i].
         # b) Padding mask: Do not attend to positions where the event token is 0.
 
         # a) Time-based causal mask
         t_i = time_seq.unsqueeze(-1) # (B, L, 1)
         t_j = time_seq.unsqueeze(1)  # (B, 1, L)
-        time_mask = (t_j < t_i)
+        time_mask = (t_j <= t_i)
 
         # b) Padding mask (prevents attending to key positions that are padding)
         padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L)
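
Note on the mask convention used above: `custom_mask` and `time_mask` are boolean with True meaning "may attend", while `nn.MultiheadAttention` expects `attn_mask` with True meaning "blocked", and a per-sample (B, L, L) mask must be expanded to (B * num_heads, L, L). The following is a minimal, self-contained sketch of that conversion, mirroring the `~custom_mask` and `repeat_interleave(self.n_head, dim=0)` lines added to `Block.forward`; the toy shapes and time values are assumptions for illustration, not taken from this repository.

    import torch
    import torch.nn as nn

    # Toy sizes for illustration only.
    B, L, n_embd, n_head = 2, 4, 8, 2
    x = torch.randn(B, L, n_embd)

    # Time-based rule from models.py: position i may attend to j only if t[j] <= t[i].
    t = torch.tensor([[0, 1, 2, 3],
                      [0, 0, 5, 9]])
    allowed = t.unsqueeze(1) <= t.unsqueeze(-1)               # (B, L, L), True = may attend

    attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
    # Invert (True now means "blocked") and repeat per head, as nn.MultiheadAttention expects.
    attn_mask = (~allowed).repeat_interleave(n_head, dim=0)   # (B * n_head, L, L)
    y, _ = attn(x, x, x, attn_mask=attn_mask, need_weights=False)
    print(y.shape)                                            # torch.Size([2, 4, 8])

Because every position may attend to itself under the `<=` rule, no row of the mask is fully blocked, which avoids NaN attention weights.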

train.py

@@ -5,6 +5,7 @@ from torch.utils.data import DataLoader
 import numpy as np
 import math
 import tqdm
+import matplotlib.pyplot as plt
 
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
@@ -14,7 +15,7 @@ class TrainConfig:
     # Data parameters
     train_data_path = 'ukb_real_train.bin'
     val_data_path = 'ukb_real_val.bin'
-    block_length = 256 # Sequence length
+    block_length = 24 # Sequence length
 
     # Model parameters
     n_embd = 256
@@ -76,6 +77,11 @@ def main():
     # --- 3. Training Loop ---
     best_val_loss = float('inf')
     patience_counter = 0
 
+    # Lists to store losses
+    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+    val_losses_ce, val_losses_surv, val_losses_total = [], [], []
+
     print("Starting training...")
     for epoch in range(config.max_epoch):
         # --- Learning Rate Scheduling ---
@@ -120,6 +126,9 @@ def main():
         avg_train_loss_ce = train_loss_ce_acc / train_steps
         avg_train_loss_surv = train_loss_surv_acc / train_steps
+        train_losses_ce.append(avg_train_loss_ce)
+        train_losses_surv.append(avg_train_loss_surv)
+        train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
 
         # --- Validation Phase ---
         model.eval()
@@ -147,6 +156,9 @@ def main():
         avg_val_loss_ce = val_loss_ce_acc / val_steps
         avg_val_loss_surv = val_loss_surv_acc / val_steps
         total_val_loss = avg_val_loss_ce + avg_val_loss_surv
+        val_losses_ce.append(avg_val_loss_ce)
+        val_losses_surv.append(avg_val_loss_surv)
+        val_losses_total.append(total_val_loss)
 
         print(f"Epoch {epoch+1} Summary: \n"
               f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
@@ -157,8 +169,10 @@ def main():
         if total_val_loss < best_val_loss:
             best_val_loss = total_val_loss
             patience_counter = 0
-            print(f"Validation loss improved to {best_val_loss:.4f}. Resetting patience.")
+            print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
+            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
         else:
-            patience_counter += 1
-            print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
+            if epoch >= config.warmup_epochs:
+                patience_counter += 1
+                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
 
@@ -166,5 +180,55 @@ def main():
print("\nEarly stopping triggered due to no improvement in validation loss.") print("\nEarly stopping triggered due to no improvement in validation loss.")
break break
# --- Save Best Model at the End ---
if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
print("Saving final best model to best_model.pt")
torch.save(model.state_dict(), 'best_model.pt')
else:
print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total)
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(18, 5))
# Plot CE Loss
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses_ce, label='Train CE')
plt.plot(epochs, val_losses_ce, label='Val CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Survival Loss
plt.subplot(1, 3, 2)
plt.plot(epochs, train_losses_surv, label='Train Survival')
plt.plot(epochs, val_losses_surv, label='Val Survival')
plt.title('Survival Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
# Plot Total Loss
plt.subplot(1, 3, 3)
plt.plot(epochs, train_losses_total, label='Train Total')
plt.plot(epochs, val_losses_total, label='Val Total')
plt.title('Total Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves.png')
print("\nLoss curves saved to loss_curves.png")
if __name__ == '__main__': if __name__ == '__main__':
main() main()
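
As a reference for the second bullet of the commit message, here is a condensed, runnable sketch of the warmup-gated early stopping that train.py now implements. The toy model, the fake validation losses, and the config values are stand-ins for illustration; the exact stopping check in train.py is not shown in these hunks, so the `patience_counter >= config.early_stopping_patience` comparison below is a plausible paraphrase of it.

    import torch
    import torch.nn as nn

    class Config:                                    # assumed values, not the real TrainConfig
        max_epoch = 10
        warmup_epochs = 3
        early_stopping_patience = 2

    config = Config()
    model = nn.Linear(4, 4)                          # stand-in for TimeAwareGPT2
    fake_val_losses = [5.0, 4.0, 4.2, 4.3, 4.1, 4.4, 4.5, 4.6, 4.7, 4.8]

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(config.max_epoch):
        total_val_loss = fake_val_losses[epoch]      # stand-in for the validation phase

        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
        else:
            # Patience only accumulates after warmup, so early stopping
            # cannot trigger during the warmup epochs.
            if epoch >= config.warmup_epochs:
                patience_counter += 1

        if patience_counter >= config.early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch + 1}.")
            break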