import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


# --- Main Training Script ---
def main():
    config = TrainConfig()
    device = torch.device(config.device)

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,
                            num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    )

    # --- Multi-GPU Support ---
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)
    model.to(device)

    num_params = model.module.get_num_params() if isinstance(model, nn.DataParallel) else model.get_num_params()
    print(f"Model initialized with {num_params:.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
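    # Learning-rate schedule used below: hold lr_initial for the first
    # `warmup_epochs` epochs, then decay to lr_final along a cosine curve
    # over the remaining epochs:
    #   progress = (epoch - warmup_epochs) / (max_epoch - warmup_epochs)
    #   lr = lr_final + 0.5 * (lr_initial - lr_final) * (1 + cos(pi * progress))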
    best_val_loss = float('inf')
    patience_counter = 0

    # Lists to store losses
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    print("Starting training...")
    for epoch in range(config.max_epoch):
        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0
        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        for event_seq, time_seq in pbar:
            event_seq, time_seq = event_seq.to(device), time_seq.to(device)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # DataParallel gathers the logits onto a single device before loss_fn
            # is applied, so the loss is already a scalar here; .mean() is kept as
            # a harmless safeguard in case the loss is ever computed per replica.
            if isinstance(model, nn.DataParallel):
                loss = loss.mean()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                              'loss_surv': f'{loss_survival.item():.4f}',
                              'lr': f'{lr:.2e}'})

        avg_train_loss_ce = train_loss_ce_acc / train_steps
        avg_train_loss_surv = train_loss_surv_acc / train_steps
        train_losses_ce.append(avg_train_loss_ce)
        train_losses_surv.append(avg_train_loss_surv)
        train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)

        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0
        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
            for event_seq, time_seq in pbar_val:
                event_seq, time_seq = event_seq.to(device), time_seq.to(device)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                                      'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv
        val_losses_ce.append(avg_val_loss_ce)
        val_losses_surv.append(avg_val_loss_surv)
        val_losses_total.append(total_val_loss)

        print(f"Epoch {epoch+1} Summary: \n"
              f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
              f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
              f" Learning Rate: {lr:.6f}")

        # --- Early Stopping Check ---
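        # A checkpoint is written on every improvement of the total validation
        # loss; patience only accumulates after the warmup epochs, and training
        # stops once `early_stopping_patience` consecutive epochs fail to improve
        # on the best value.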
        if total_val_loss < best_val_loss:
            best_val_loss = total_val_loss
            patience_counter = 0
            print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
            # Save the underlying model state_dict when using DataParallel
            model_state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(model_state, 'best_model_checkpoint.pt')
        else:
            if epoch >= config.warmup_epochs:
                patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
            if patience_counter >= config.early_stopping_patience:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
                break

    # --- Save Best Model at the End ---
    if best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        # Load the state dict into the base model, not the DataParallel wrapper
        base_model = model.module if isinstance(model, nn.DataParallel) else model
        base_model.load_state_dict(torch.load('best_model_checkpoint.pt'))
        print("Saving final best model to best_model.pt")
        torch.save(base_model.state_dict(), 'best_model.pt')
    else:
        print("\nTraining finished. No best model to save as validation loss never improved.")

    # --- Plot and Save Loss Curves ---
    num_epochs = len(train_losses_total)
    epochs = range(1, num_epochs + 1)
    plt.figure(figsize=(18, 5))

    # Plot CE Loss
    plt.subplot(1, 3, 1)
    plt.plot(epochs, train_losses_ce, label='Train CE')
    plt.plot(epochs, val_losses_ce, label='Val CE')
    plt.title('Cross-Entropy Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Survival Loss
    plt.subplot(1, 3, 2)
    plt.plot(epochs, train_losses_surv, label='Train Survival')
    plt.plot(epochs, val_losses_surv, label='Val Survival')
    plt.title('Survival Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Total Loss
    plt.subplot(1, 3, 3)
    plt.plot(epochs, train_losses_total, label='Train Total')
    plt.plot(epochs, val_losses_total, label='Val Total')
    plt.title('Total Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('loss_curves.png')
    print("\nLoss curves saved to loss_curves.png")


if __name__ == '__main__':
    main()