"""Distributed (DDP) training script for TimeAwareGPT2 on patient event sequences.

Reads LOCAL_RANK from the environment and initialises an NCCL process group,
so it presumably expects to be launched via torchrun (or an equivalent launcher
that sets the env:// rendezvous variables) -- confirm against the run scripts.
"""
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    """Hyperparameters and paths for one training run.

    Defaults are class attributes; ``main`` overrides ``device`` on the
    instance once the local rank is known.
    """

    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10

    # Loss parameters
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    # System parameters (will be updated by DDP setup)
    device = 'cuda'


# --- DDP Setup ---
def setup_ddp():
    """Initializes the distributed data parallel environment.

    Returns:
        (rank, local_rank): the global rank of this process and the CUDA
        device index (from the LOCAL_RANK environment variable) it now uses.
    """
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return rank, local_rank


def cleanup_ddp():
    """Cleans up the distributed data parallel environment."""
    dist.destroy_process_group()


# --- Helpers ---
def _config_to_dict(config):
    """Return every public, non-callable setting of *config* as a plain dict.

    ``vars(config)`` alone only holds attributes assigned on the instance
    (here just ``device``); the defaults live on the class, so they are
    collected first and instance-level overrides replace them.
    """
    def public(mapping):
        return {k: v for k, v in mapping.items()
                if not k.startswith('__') and not callable(v)}

    settings = public(vars(type(config)))
    settings.update(public(vars(config)))
    return settings


def _lr_for_epoch(config, epoch):
    """Learning rate for *epoch* (0-based).

    Held at ``lr_initial`` for the first ``warmup_epochs`` epochs, then
    cosine-decayed from ``lr_initial`` towards ``lr_final`` over the remaining
    ``max_epoch - warmup_epochs`` epochs.
    """
    if epoch < config.warmup_epochs:
        return config.lr_initial
    progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
    return config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))


def _split_batch(event_seq, time_seq, device):
    """Move a batch to *device* and build next-step inputs and targets.

    Inputs are positions ``[:-1]``; targets are the events at ``[1:]`` and
    the (float) difference between consecutive time values.
    """
    event_seq, time_seq = event_seq.to(device), time_seq.to(device)
    input_events, input_times = event_seq[:, :-1], time_seq[:, :-1]
    target_events = event_seq[:, 1:]
    target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
    return input_events, input_times, target_events, target_wait_times


def _run_epoch(model, loader, loss_fn, device, desc, show_progress, optimizer=None, lr=None):
    """Run one pass over *loader*: trains when *optimizer* is given, else evaluates.

    Returns:
        (ce_sum, surv_sum, steps) accumulated on THIS rank only; combine
        across ranks with ``_global_mean``.
    """
    training = optimizer is not None
    model.train(training)
    ce_sum, surv_sum, steps = 0.0, 0.0, 0
    pbar = tqdm.tqdm(loader, desc=desc, disable=not show_progress)
    with torch.set_grad_enabled(training):
        for event_seq, time_seq in pbar:
            input_events, input_times, target_events, target_wait_times = _split_batch(event_seq, time_seq, device)
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

            if training:
                optimizer.zero_grad()
                (loss_ce + loss_survival).backward()
                optimizer.step()

            # One host sync per loss per step (reused for the progress bar).
            ce_value, surv_value = loss_ce.item(), loss_survival.item()
            ce_sum += ce_value
            surv_sum += surv_value
            steps += 1

            if show_progress:
                postfix = {'loss_ce': f'{ce_value:.4f}', 'loss_surv': f'{surv_value:.4f}'}
                if training:
                    postfix['lr'] = f'{lr:.2e}'
                pbar.set_postfix(postfix)
    return ce_sum, surv_sum, steps


def _global_mean(loss_sums, steps, device):
    """Average per-step loss sums over the steps of ALL ranks.

    Each rank only iterates its DistributedSampler shard, so sums and step
    counts are all-reduced first; every rank receives the same means.

    Raises:
        RuntimeError: if no rank processed a single batch.
    """
    stats = torch.tensor([*loss_sums, steps], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_steps = stats[-1].item()
    if total_steps == 0:
        raise RuntimeError("No batches were processed; check the dataset and batch size.")
    return tuple((stats[:-1] / total_steps).tolist())


def _save_losses(losses_filename, history):
    """Write the per-epoch loss history as CSV (1-based epoch column)."""
    rows = zip(history['train_ce'], history['train_surv'], history['train_total'],
               history['val_ce'], history['val_surv'], history['val_total'])
    with open(losses_filename, 'w') as f:
        f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
        for i, (tr_ce, tr_surv, tr_total, va_ce, va_surv, va_total) in enumerate(rows, start=1):
            f.write(f"{i},{tr_ce},{tr_surv},{tr_total},{va_ce},{va_surv},{va_total}\n")


def _plot_losses(history, path='loss_curves.png'):
    """Save side-by-side train/val curves for CE, survival and total loss."""
    epochs = range(1, len(history['train_total']) + 1)
    panels = (
        ('train_ce', 'val_ce', 'Train CE', 'Val CE', 'Cross-Entropy Loss'),
        ('train_surv', 'val_surv', 'Train Survival', 'Val Survival', 'Survival Loss'),
        ('train_total', 'val_total', 'Train Total', 'Val Total', 'Total Loss'),
    )
    plt.figure(figsize=(18, 5))
    for idx, (train_key, val_key, train_label, val_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.plot(epochs, history[train_key], label=train_label)
        plt.plot(epochs, history[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


# --- Main Training Script ---
def main():
    """Train TimeAwareGPT2 with DDP, early stopping and best-checkpoint saving.

    Rank 0 writes the config JSON, the best checkpoint, the final model, the
    loss CSV and the loss-curve PNG; all ranks train and validate on their
    DistributedSampler shard.
    """
    rank, local_rank = setup_ddp()
    is_main_process = (rank == 0)
    config = TrainConfig()
    config.device = f'cuda:{local_rank}'

    suffix = f"n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}"
    model_filename = f"best_model_{suffix}.pt"
    checkpoint_filename = f"best_model_checkpoint_{suffix}.pt"

    # --- 0. Save Configuration ---
    if is_main_process:
        config_filename = f"config_{suffix}.json"
        with open(config_filename, 'w') as f:
            json.dump(_config_to_dict(config), f, indent=4)
        print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    if is_main_process:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    # Flat uint32 files viewed as rows of 3 values; column 2 is treated as the
    # token id (the vocabulary size is inferred from its maximum).
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main_process:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler,
                            num_workers=4, pin_memory=True)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main_process:
        print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop,
    ).to(config.device)
    model = DDP(model, device_ids=[local_rank])
    if is_main_process:
        print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0
    history = {key: [] for key in ('train_ce', 'train_surv', 'train_total',
                                   'val_ce', 'val_surv', 'val_total')}

    if is_main_process:
        print("Starting training...")

    for epoch in range(config.max_epoch):
        train_sampler.set_epoch(epoch)  # Important for shuffling
        lr = _lr_for_epoch(config, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        ce_sum, surv_sum, steps = _run_epoch(
            model, train_loader, loss_fn, config.device,
            desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]",
            show_progress=is_main_process, optimizer=optimizer, lr=lr)
        avg_train_loss_ce, avg_train_loss_surv = _global_mean((ce_sum, surv_sum), steps, config.device)

        ce_sum, surv_sum, steps = _run_epoch(
            model, val_loader, loss_fn, config.device,
            desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]",
            show_progress=is_main_process)
        avg_val_loss_ce, avg_val_loss_surv = _global_mean((ce_sum, surv_sum), steps, config.device)
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        history['train_ce'].append(avg_train_loss_ce)
        history['train_surv'].append(avg_train_loss_surv)
        history['train_total'].append(avg_train_loss_ce + avg_train_loss_surv)
        history['val_ce'].append(avg_val_loss_ce)
        history['val_surv'].append(avg_val_loss_surv)
        history['val_total'].append(total_val_loss)

        should_stop = False
        if is_main_process:
            print(f"Epoch {epoch+1} Summary: \n"
                  f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f" Learning Rate: {lr:.6f}")

            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                torch.save(model.module.state_dict(), checkpoint_filename)
            else:
                # Patience only starts counting once warmup is over.
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
            should_stop = patience_counter >= config.early_stopping_patience

        # Only rank 0 tracks patience. Broadcast its decision so every rank
        # leaves the loop together; a lone rank breaking out would leave the
        # others blocked in DDP's gradient all-reduce on the next epoch.
        stop_flag = torch.tensor([int(should_stop)], device=config.device)
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item():
            if is_main_process:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    if is_main_process:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # Load the best weights into the module before saving the final model
            model.module.load_state_dict(torch.load(checkpoint_filename, map_location=config.device))
            print(f"Saving final best model to {model_filename}")
            torch.save(model.module.state_dict(), model_filename)
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")

        losses_filename = f"losses_{suffix}.txt"
        _save_losses(losses_filename, history)
        print(f"\nLosses saved to {losses_filename}")

        _plot_losses(history)
        print("\nLoss curves saved to loss_curves.png")

    cleanup_ddp()


if __name__ == '__main__':
    main()