import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # Sequence length

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_iter = 200000
    batch_size = 128  # Per GPU
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_iter = 1000

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]  # Example ignored token IDs

    # System parameters
    device = 'cuda'


# --- DDP Setup ---
def setup_ddp():
    """Initializes the distributed data parallel environment."""
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return rank, local_rank


def cleanup_ddp():
    """Cleans up the distributed data parallel environment."""
    dist.destroy_process_group()


# --- Main Training Script ---
def main():
    rank, local_rank = setup_ddp()
    is_main_process = (rank == 0)

    config = TrainConfig()
    config.device = f'cuda:{local_rank}'

    if is_main_process:
        model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.pt"

        # --- 0. Save Configuration ---
        config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.json"
        # Note: vars(config) only sees instance attributes (here just 'device'),
        # so collect the class-level settings as well when serializing.
        config_dict = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
        with open(config_filename, 'w') as f:
            json.dump(config_dict, f, indent=4)
        print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    if is_main_process:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")

    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main_process:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, sampler=val_sampler,
                            num_workers=4, pin_memory=True)

    train_iter_loader = iter(itertools.cycle(train_loader))
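    # Assumption (inferred from how batches are unpacked below): each batch from
    # these loaders is a pair (event_seq, time_seq) of integer tensors of shape
    # (batch_size, block_length), holding event tokens and their event times.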
    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main_process:
        print(f"Initializing model on {config.device}...")

    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)
    model = DDP(model, device_ids=[local_rank])

    if is_main_process:
        print(f"Model initialized with {model.module.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay)

    # --- 3. Training Loop ---
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []

    if is_main_process:
        print("Starting training...")

    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training", disable=not is_main_process)
    for iter_num in pbar:
        # --- Learning Rate Scheduling ---
        # Hold the initial rate for warmup_iter steps, then cosine-decay to lr_final.
        if iter_num < config.warmup_iter:
            lr = config.lr_initial
        else:
            progress = (iter_num - config.warmup_iter) / (config.max_iter - config.warmup_iter)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Step ---
        model.train()
        event_seq, time_seq = next(train_iter_loader)
        event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

        input_events = event_seq[:, :-1]
        input_times = time_seq[:, :-1]
        target_events = event_seq[:, 1:]
        target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

        logits = model(input_events, input_times)
        loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
        loss = loss_ce + loss_survival

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if is_main_process:
            train_losses_ce.append(loss_ce.item())
            train_losses_surv.append(loss_survival.item())
            train_losses_total.append(loss.item())
            pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                              'loss_surv': f'{loss_survival.item():.4f}',
                              'lr': f'{lr:.2e}'})

    if is_main_process:
        print("\nTraining finished.")

    # --- 4. Final Validation ---
    if is_main_process:
        print("Running final validation...")

    model.eval()
    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
    val_steps = 0
    with torch.no_grad():
        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation", disable=not is_main_process)
        for event_seq, time_seq in pbar_val:
            event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

            val_loss_ce_acc += loss_ce.item()
            val_loss_surv_acc += loss_survival.item()
            val_steps += 1

            if is_main_process:
                pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                                      'loss_surv': f'{loss_survival.item():.4f}'})

    # Note: these averages are computed from each rank's shard of the validation
    # set; only rank 0's figures are reported below.
    avg_val_loss_ce = val_loss_ce_acc / val_steps
    avg_val_loss_surv = val_loss_surv_acc / val_steps
    total_val_loss = avg_val_loss_ce + avg_val_loss_surv

    if is_main_process:
        print(f"Final Validation Summary: \n"
              f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")

        # --- 5. Save Model ---
        print(f"Saving final model to {model_filename}")
        torch.save(model.module.state_dict(), model_filename)
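        # The checkpoint above uses model.module.state_dict() so the DDP wrapper is
        # stripped and the weights load into a plain TimeAwareGPT2 without
        # 'module.'-prefixed keys.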
        # --- 6. Save and Plot Losses ---
        losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter_multigpu.txt"
        with open(losses_filename, 'w') as f:
            f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")

        # Plot and Save Loss Curves
        iterations = range(1, len(train_losses_total) + 1)
        plt.figure(figsize=(18, 5))

        # Plot CE Loss
        plt.subplot(1, 3, 1)
        plt.plot(iterations, train_losses_ce, label='Train CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Survival Loss
        plt.subplot(1, 3, 2)
        plt.plot(iterations, train_losses_surv, label='Train Survival')
        plt.title('Survival Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Plot Total Loss
        plt.subplot(1, 3, 3)
        plt.plot(iterations, train_losses_total, label='Train Total')
        plt.title('Total Loss')
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves_iter_multigpu.png')
        print("\nLoss curves saved to loss_curves_iter_multigpu.png")

    cleanup_ddp()


if __name__ == '__main__':
    main()
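# Launch sketch (assumes this script is saved as train_multigpu.py; adjust the
# GPU count to your machine):
#   torchrun --nproc_per_node=4 train_multigpu.py
# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE in the environment, which
# setup_ddp() and dist.init_process_group(backend='nccl') rely on.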