# train_ddp.py
import os
import math
import numpy as np
import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


def ddp_is_active():
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1


def ddp_setup():
    """Initialize process group if launched by torchrun."""
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://')
        return True
    return False


def ddp_cleanup():
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


def all_reduce_mean(value: torch.Tensor):
    """Mean across processes (no-op if not DDP)."""
    if ddp_is_active():
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value /= dist.get_world_size()
    return value


def bcast_bool(stop: bool):
    if not ddp_is_active():
        return stop
    t = torch.tensor(1 if stop else 0,
                     device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    dist.broadcast(t, src=0)
    return bool(t.item())


def main():
    config = TrainConfig()

    # --- DDP setup ---
    ddp_enabled = ddp_setup()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = int(os.environ.get("RANK", 0)) if ddp_enabled else 0
    world_size = int(os.environ.get("WORLD_SIZE", 1)) if ddp_enabled else 1

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
    else:
        device = torch.device(config.device)

    # Seed per-rank (different but deterministic)
    torch.manual_seed(1337 + global_rank)
    np.random.seed(1337 + global_rank)

    is_main = (global_rank == 0)
    if is_main:
        print(f"DDP enabled: {ddp_enabled} | world_size={world_size} | device={device}")
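
    # NOTE (assumption): the .bin files are assumed to hold flat uint32 records
    # with three fields per event (hence the reshape to (-1, 3) below), and the
    # event/token label in the third column, which is where the vocabulary size
    # is inferred from. Adjust the reshape and column index if your export
    # format differs.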
    # --- 1. Data Loading ---
    if is_main:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # DDP samplers (fall back to regular shuffle when not DDP)
    if ddp_enabled:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
                                           rank=global_rank, shuffle=True, drop_last=False)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size,
                                         rank=global_rank, shuffle=False, drop_last=False)
        shuffle_flag = False
    else:
        train_sampler, val_sampler = None, None
        shuffle_flag = True

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=shuffle_flag,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    # --- 2. Model, Optimizer, Loss ---
    if is_main:
        print(f"Initializing model on {device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)

    if is_main:
        # If your model has get_num_params
        try:
            print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
        except Exception:
            total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"Model initialized with {total_params/1e6:.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # Wrap with DDP
    if ddp_enabled:
        model = DDP(model,
                    device_ids=[local_rank] if device.type == 'cuda' else None,
                    output_device=local_rank if device.type == 'cuda' else None,
                    find_unused_parameters=False)

    # AMP
    use_amp = (device.type == 'cuda')
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
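
    # NOTE: torch.cuda.amp.GradScaler/autocast still work but are deprecated in
    # recent PyTorch releases in favour of torch.amp.GradScaler('cuda', ...) and
    # torch.amp.autocast('cuda', ...); consider switching if you target
    # PyTorch >= 2.4.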
    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    # store losses only on main to plot
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main:
        print("Starting training...")

    for epoch in range(config.max_epoch):
        # Ensure different shuffles per epoch under DDP
        if ddp_enabled:
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

        # --- LR scheduling (same as original) ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # --- Training ---
        if is_main:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader  # silent on non-main ranks

        model.train()
        train_ce_sum = torch.tensor(0.0, device=device)
        train_surv_sum = torch.tensor(0.0, device=device)
        train_steps = 0

        for batch in pbar:
            event_seq, time_seq = batch
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
                loss = loss_ce + loss_survival

            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            train_ce_sum += loss_ce.detach()
            train_surv_sum += loss_survival.detach()
            train_steps += 1

            if is_main and isinstance(pbar, tqdm.tqdm):
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                                  'loss_surv': f'{loss_survival.item():.4f}',
                                  'lr': f'{lr:.2e}'})

        # Average train losses across ranks
        train_ce_avg = train_ce_sum / max(1, train_steps)
        train_surv_avg = train_surv_sum / max(1, train_steps)
        train_ce_avg = all_reduce_mean(train_ce_avg)
        train_surv_avg = all_reduce_mean(train_surv_avg)
        train_total_avg = train_ce_avg + train_surv_avg

        # --- Validation ---
        if is_main:
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
        else:
            pbar_val = val_loader

        model.eval()
        val_ce_sum = torch.tensor(0.0, device=device)
        val_surv_sum = torch.tensor(0.0, device=device)
        val_steps = 0

        with torch.no_grad():
            for batch in pbar_val:
                event_seq, time_seq = batch
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                with torch.cuda.amp.autocast(enabled=use_amp):
                    logits = model(input_events, input_times)
                    loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_ce_sum += loss_ce.detach()
                val_surv_sum += loss_survival.detach()
                val_steps += 1

                if is_main and isinstance(pbar_val, tqdm.tqdm):
                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                                          'loss_surv': f'{loss_survival.item():.4f}'})

        # Average val losses across ranks
        val_ce_avg = val_ce_sum / max(1, val_steps)
        val_surv_avg = val_surv_sum / max(1, val_steps)
        val_ce_avg = all_reduce_mean(val_ce_avg)
        val_surv_avg = all_reduce_mean(val_surv_avg)
        val_total_avg = val_ce_avg + val_surv_avg
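
        # The loss averages are all-reduced above so every rank holds the same
        # global means; rank 0 then makes the early-stopping decision on those
        # global values and broadcasts the stop flag below, keeping all ranks
        # exiting the loop together.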
        # --- Logging & Early Stopping (rank 0) ---
        stop_now = False
        if is_main:
            print(f"Epoch {epoch+1} Summary:\n"
                  f"  Train Loss: {float(train_total_avg):.4f} (CE: {float(train_ce_avg):.4f}, Surv: {float(train_surv_avg):.4f})\n"
                  f"  Val Loss:   {float(val_total_avg):.4f} (CE: {float(val_ce_avg):.4f}, Surv: {float(val_surv_avg):.4f})\n"
                  f"  Learning Rate: {lr:.6f}")

            # Record for curves
            train_losses_ce.append(float(train_ce_avg))
            train_losses_surv.append(float(train_surv_avg))
            train_losses_total.append(float(train_total_avg))
            val_losses_ce.append(float(val_ce_avg))
            val_losses_surv.append(float(val_surv_avg))
            val_losses_total.append(float(val_total_avg))

            if val_total_avg < best_val_loss:
                best_val_loss = float(val_total_avg)
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint (rank0)...")
                # unwrap DDP
                to_save = model.module if isinstance(model, DDP) else model
                torch.save(to_save.state_dict(), 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
                if patience_counter >= config.early_stopping_patience:
                    print("\nEarly stopping triggered due to no improvement in validation loss.")
                    stop_now = True

        # Broadcast stop flag so all ranks exit together
        stop_now = bcast_bool(stop_now)
        if stop_now:
            break

    # --- Save Best Model at the End (rank 0 only) ---
    if is_main:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # unwrap for loading & final save
            to_load = model.module if isinstance(model, DDP) else model
            to_load.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
            print("Saving final best model to best_model.pt")
            torch.save(to_load.state_dict(), 'best_model.pt')
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")

        # --- Plot and Save Loss Curves ---
        num_epochs = len(train_losses_total)
        if num_epochs > 0:
            epochs = range(1, num_epochs + 1)
            plt.figure(figsize=(18, 5))

            plt.subplot(1, 3, 1)
            plt.plot(epochs, train_losses_ce, label='Train CE')
            plt.plot(epochs, val_losses_ce, label='Val CE')
            plt.title('Cross-Entropy Loss')
            plt.xlabel('Epochs'); plt.ylabel('Loss')
            plt.legend(); plt.grid(True)

            plt.subplot(1, 3, 2)
            plt.plot(epochs, train_losses_surv, label='Train Survival')
            plt.plot(epochs, val_losses_surv, label='Val Survival')
            plt.title('Survival Loss')
            plt.xlabel('Epochs'); plt.ylabel('Loss')
            plt.legend(); plt.grid(True)

            plt.subplot(1, 3, 3)
            plt.plot(epochs, train_losses_total, label='Train Total')
            plt.plot(epochs, val_losses_total, label='Val Total')
            plt.title('Total Loss')
            plt.xlabel('Epochs'); plt.ylabel('Loss')
            plt.legend(); plt.grid(True)

            plt.tight_layout()
            plt.savefig('loss_curves.png')
            print("\nLoss curves saved to loss_curves.png")

    # Clean up DDP
    ddp_cleanup()


if __name__ == '__main__':
    main()
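
# Example launch (assumed single-node setup; adjust --nproc_per_node to the
# number of available GPUs):
#   torchrun --standalone --nproc_per_node=4 train_ddp.py
# Running plainly as `python train_ddp.py` skips DDP setup (no RANK/WORLD_SIZE
# in the environment) and trains in a single process on config.device.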