feat: Implement time-aware GPT-2 for patient event prediction
This commit introduces a complete framework for training a temporal GPT-2 model on sequential patient event data. Key components:

- `models.py`:
  - `TimeAwareGPT2`: a custom GPT-2 model that incorporates temporal information through a time-based causal attention mask and a sinusoidal age encoding for positional information.
  - `AgeSinusoidalEncoding`: a module for creating time-based positional embeddings (see the illustrative sketch below).
  - `CombinedLoss`: a two-part loss function combining cross-entropy for event prediction with a survival loss for event timing.
- `utils.py`:
  - `PatientEventDataset`: a PyTorch `Dataset` class that processes, batches, and loads patient event sequences, including imputation of "no event" gaps and padding/truncation.
- `train.py`:
  - A training script that initializes the model, data loaders, and loss function.
  - Implements a training loop with a cosine-annealing learning-rate scheduler, validation, and early stopping based on validation loss.
- `prepare_data.py`:
  - Script for preprocessing raw UK Biobank data into a format suitable for the model.
- `GEMINI.md`:
  - Project documentation outlining the structure, coding style, and framework.
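For orientation, here is a minimal sketch of the idea behind the sinusoidal age encoding: the standard transformer positional-encoding formula applied to continuous age/time values instead of integer positions. This is an illustration only, not the `models.py` implementation; the class name, frequency schedule, and the assumption that ages arrive as a float `(batch, seq_len)` tensor are all illustrative.

```python
import math
import torch
import torch.nn as nn

class AgeSinusoidalEncodingSketch(nn.Module):
    """Illustrative sinusoidal encoding of continuous age/time values."""

    def __init__(self, n_embd: int):
        super().__init__()
        assert n_embd % 2 == 0, "n_embd must be even to pair sin/cos channels"
        half = n_embd // 2
        # Geometric frequency schedule, as in the original transformer encoding.
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
        self.register_buffer('freqs', freqs)

    def forward(self, ages: torch.Tensor) -> torch.Tensor:
        # ages: (batch, seq_len) float tensor, e.g. age at each event.
        angles = ages.unsqueeze(-1) * self.freqs  # (batch, seq_len, n_embd // 2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (batch, seq_len, n_embd)
```

In the actual model, this per-event embedding would presumably be added to the event-token embeddings in place of a learned integer position embedding, which is what lets the attention layers see when each event happened rather than only its order.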
train.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 256  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Main Training Script ---
def main():
    config = TrainConfig()

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
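    # Each .bin file is a flat uint32 array viewed as (N, 3) records; this script
    # only relies on column 2 holding the integer event label, and hands the full
    # array to PatientEventDataset for sequence construction.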
    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)

    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
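    # CombinedLoss (see models.py) returns two terms: a cross-entropy loss over the
    # identity of the next event and a survival loss over its timing; token ids 0
    # (padding) and 1 ("no event") are excluded via ignored_token_ids.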
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0
    print("Starting training...")
    for epoch in range(config.max_epoch):
        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
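        # Schedule recap: the rate is held flat at lr_initial for the first
        # warmup_epochs, then follows a cosine curve from lr_initial (progress = 0)
        # down toward lr_final (progress = 1).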
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
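            # Shift-by-one targets: the model at position t predicts the event at t+1,
            # and the survival term is supervised with the waiting time until that event.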
            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps

        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        print(f"Epoch {epoch+1} Summary:\n"
              f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
              f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
              f"  Learning Rate: {lr:.6f}")

        # --- Early Stopping Check ---
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            print(f"Validation loss improved to {best_val_loss:.4f}. Resetting patience.")
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")

        if patience_counter >= config.early_stopping_patience:
            print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

if __name__ == '__main__':
    main()