import os
import json
import math
import argparse
from typing import Tuple

import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
import tqdm
import matplotlib.pyplot as plt

from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from utils import PatientEventDataset


class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1
    model_name = 'TimeAwareGPT2'

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10
    betas = (0.9, 0.99)

    # Loss parameters (ignored tokens)
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]


def setup_ddp(backend: str | None = None):
    """Initialize torch.distributed from environment variables set by torchrun."""
    if backend is None:
        if torch.cuda.is_available() and os.name != 'nt':
            backend = 'nccl'
        else:
            backend = 'gloo'
    dist.init_process_group(backend=backend)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f'cuda:{local_rank}')
    else:
        device = torch.device('cpu')
    return rank, world_size, local_rank, device


def cleanup_ddp():
    if dist.is_initialized():
        dist.destroy_process_group()


def cosine_lr(epoch: int, cfg: TrainConfig) -> float:
    if epoch < cfg.warmup_epochs:
        return cfg.lr_initial
    progress = (epoch - cfg.warmup_epochs) / max(1, (cfg.max_epoch - cfg.warmup_epochs))
    return cfg.lr_final + 0.5 * (cfg.lr_initial - cfg.lr_final) * (1 + math.cos(math.pi * progress))


def allreduce_avg(value: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-reduce sum then divide by world_size."""
    value = value.clone().to(torch.float64)
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= world_size
    return value.to(torch.float32)
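
# Launch sketch (illustration only): torchrun populates RANK, LOCAL_RANK and WORLD_SIZE,
# which setup_ddp() reads above. Assuming this file is saved as train_ddp.py (the filename
# is an assumption), a single-node run on 4 GPUs would look like:
#
#   torchrun --standalone --nproc_per_node=4 train_ddp.py --n_layer 12 --n_embd 120 --batch_size 128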

def main():
    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model (DDP). Use torchrun to launch.')
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--n_embd', type=int, default=120)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr_initial', type=float, default=6e-4)
    parser.add_argument('--lr_final', type=float, default=6e-5)
    parser.add_argument('--weight_decay', type=float, default=2e-1)
    parser.add_argument('--warmup_epochs', type=int, default=10)
    parser.add_argument('--early_stopping_patience', type=int, default=10)
    parser.add_argument('--pdrop', type=float, default=0.1)
    parser.add_argument('--token_pdrop', type=float, default=0.1)
    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99])
    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2')
    parser.add_argument('--backend', type=str, default=None, help='DDP backend (nccl/gloo). Default auto-selects.')
    args = parser.parse_args()

    rank, world_size, local_rank, device = setup_ddp(args.backend)

    # Build config
    cfg = TrainConfig()
    cfg.n_layer = args.n_layer
    cfg.n_embd = args.n_embd
    cfg.n_head = args.n_head
    cfg.max_epoch = args.max_epoch
    cfg.batch_size = args.batch_size
    cfg.lr_initial = args.lr_initial
    cfg.lr_final = args.lr_final
    cfg.weight_decay = args.weight_decay
    cfg.warmup_epochs = args.warmup_epochs
    cfg.early_stopping_patience = args.early_stopping_patience
    cfg.pdrop = args.pdrop
    cfg.token_pdrop = args.token_pdrop
    cfg.betas = tuple(args.betas)
    cfg.model_name = args.model

    # Filenames (shared across ranks)
    model_suffix = f"{cfg.model_name}_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}"
    model_filename = f"best_model_{model_suffix}.pt"
    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
    config_filename = f"config_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}.json"

    # Save config only on rank 0. Collect class defaults and instance overrides via getattr,
    # since vars(cfg) alone would miss class-level attributes that were never overridden.
    if rank == 0:
        config_dict = {k: getattr(cfg, k) for k in dir(cfg)
                       if not k.startswith('_') and not callable(getattr(cfg, k))}
        with open(config_filename, 'w') as f:
            json.dump(config_dict, f, indent=4)
        print(f"[rank 0] Configuration saved to {config_filename}")

    # Load data (all ranks)
    if rank == 0:
        print(f"Loading data from {cfg.train_data_path} and {cfg.val_data_path}...")
    train_data_arr = np.memmap(cfg.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(cfg.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, cfg.block_length)
    val_dataset = PatientEventDataset(val_data_arr, cfg.block_length)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)

    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler,
                              num_workers=4, pin_memory=torch.cuda.is_available())
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=val_sampler,
                            num_workers=4, pin_memory=torch.cuda.is_available())

    # Model, loss, optimizer
    model_cls = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }[cfg.model_name]
    model = model_cls(
        vocab_size=vocab_size,
        n_embd=cfg.n_embd,
        n_layer=cfg.n_layer,
        n_head=cfg.n_head,
        pdrop=cfg.pdrop,
        token_pdrop=cfg.token_pdrop,
    ).to(device)
    ddp_model = DDP(model,
                    device_ids=[local_rank] if torch.cuda.is_available() else None,
                    output_device=local_rank if torch.cuda.is_available() else None)

    loss_fn = CombinedLoss(cfg.ignored_token_ids)
    optimizer = AdamW(ddp_model.parameters(), lr=cfg.lr_initial, weight_decay=cfg.weight_decay, betas=cfg.betas)

    best_val_loss = float('inf')
    patience_counter = 0
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if rank == 0:
        print("Starting DDP training...")
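
    # Worked example of the target construction used in the loops below (illustration only):
    # with block_length = 4, events [e0, e1, e2, e3] and times [t0, t1, t2, t3],
    #   input_events      = [e0, e1, e2]
    #   target_events     = [e1, e2, e3]
    #   target_wait_times = [t1 - t0, t2 - t1, t3 - t2]
    # so each position is scored on the next event (CE term) and on the waiting time
    # until it occurs (survival term).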

    for epoch in range(cfg.max_epoch):
        # Update sampler epoch for shuffling
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        # Set LR
        lr = cosine_lr(epoch, cfg)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # Train
        ddp_model.train()
        train_loss_ce_acc = torch.zeros(1, device=device)
        train_loss_surv_acc = torch.zeros(1, device=device)
        train_steps = 0
        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Train]", disable=(rank != 0))
        for event_seq, time_seq in pbar:
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = ddp_model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.detach()
            train_loss_surv_acc += loss_survival.detach()
            train_steps += 1
            if rank == 0:
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                                  'loss_surv': f'{loss_survival.item():.4f}',
                                  'lr': f'{lr:.2e}'})

        # Aggregate train losses across ranks: (sum / world_size) / (steps / world_size)
        # equals the global loss sum over the global step count.
        if train_steps == 0:
            train_steps = 1
        steps_tensor = torch.tensor([train_steps], device=device, dtype=torch.float64)
        dist.all_reduce(steps_tensor, op=dist.ReduceOp.SUM)
        train_loss_ce_mean = allreduce_avg(train_loss_ce_acc, world_size) / (steps_tensor.item() / world_size)
        train_loss_surv_mean = allreduce_avg(train_loss_surv_acc, world_size) / (steps_tensor.item() / world_size)

        if rank == 0:
            train_losses_ce.append(train_loss_ce_mean.item())
            train_losses_surv.append(train_loss_surv_mean.item())
            train_losses_total.append(train_loss_ce_mean.item() + train_loss_surv_mean.item())

        # Validation
        ddp_model.eval()
        val_loss_ce_acc = torch.zeros(1, device=device)
        val_loss_surv_acc = torch.zeros(1, device=device)
        val_steps = 0
        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Val]", disable=(rank != 0))
            for event_seq, time_seq in pbar_val:
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = ddp_model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.detach()
                val_loss_surv_acc += loss_survival.detach()
                val_steps += 1

        if val_steps == 0:
            val_steps = 1
        vsteps_tensor = torch.tensor([val_steps], device=device, dtype=torch.float64)
        dist.all_reduce(vsteps_tensor, op=dist.ReduceOp.SUM)
        val_loss_ce_mean = allreduce_avg(val_loss_ce_acc, world_size) / (vsteps_tensor.item() / world_size)
        val_loss_surv_mean = allreduce_avg(val_loss_surv_acc, world_size) / (vsteps_tensor.item() / world_size)
        total_val_loss = (val_loss_ce_mean + val_loss_surv_mean).item()

        if rank == 0:
            val_losses_ce.append(val_loss_ce_mean.item())
            val_losses_surv.append(val_loss_surv_mean.item())
            val_losses_total.append(total_val_loss)
            print(
                f"Epoch {epoch+1} Summary:\n"
                f"  Train Loss: {train_losses_total[-1]:.4f} (CE: {train_losses_ce[-1]:.4f}, Surv: {train_losses_surv[-1]:.4f})\n"
                f"  Val Loss: {total_val_loss:.4f} (CE: {val_losses_ce[-1]:.4f}, Surv: {val_losses_surv[-1]:.4f})\n"
                f"  Learning Rate: {lr:.6f}"
            )

            # Early stopping on rank 0; the decision is broadcast below so all ranks stop together.
            improved = total_val_loss < best_val_loss
            if improved:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                torch.save(ddp_model.module.state_dict(), checkpoint_filename)
            else:
                if epoch >= cfg.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{cfg.early_stopping_patience}")
            # Define the flag on both branches so the broadcast never sees an undefined
            # tensor (previously it was only set when the loss failed to improve).
            stop_flag = torch.tensor([1.0 if patience_counter >= cfg.early_stopping_patience else 0.0], device=device)
        else:
            stop_flag = torch.zeros(1, device=device)

        # Broadcast the stop flag from rank 0 to all ranks
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item() > 0:
            if rank == 0:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # Save best model at the end (rank 0)
    if rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        state = torch.load(checkpoint_filename, map_location='cpu')
        ddp_model.module.load_state_dict(state)
        print(f"Saving final best model to {model_filename}")
        torch.save(ddp_model.module.state_dict(), model_filename)

        # Save losses to file
        losses_filename = f"losses_{model_suffix}.txt"
        with open(losses_filename, 'w') as f:
            f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},"
                        f"{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")

        # Plot curves
        num_epochs = len(train_losses_total)
        epochs = range(1, num_epochs + 1)
        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.plot(epochs, train_losses_ce, label='Train CE')
        plt.plot(epochs, val_losses_ce, label='Val CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 2)
        plt.plot(epochs, train_losses_surv, label='Train Survival')
        plt.plot(epochs, val_losses_surv, label='Val Survival')
        plt.title('Survival Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 3)
        plt.plot(epochs, train_losses_total, label='Train Total')
        plt.plot(epochs, val_losses_total, label='Val Total')
        plt.title('Total Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves.png')
        print("\nLoss curves saved to loss_curves.png")

    # All ranks tear down the process group, not just rank 0.
    cleanup_ddp()


if __name__ == '__main__':
    main()
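
# Reload sketch (illustration only, not executed here): the final weights saved above can be
# restored into a freshly constructed model for evaluation. The constructor arguments mirror
# the ones used in main(); vocab_size must match the value inferred from the training data,
# and the filename follows the best_model_{model_suffix}.pt pattern with the default sizes.
#
#   model = TimeAwareGPT2(vocab_size=vocab_size, n_embd=120, n_layer=12, n_head=12,
#                         pdrop=0.0, token_pdrop=0.0)
#   model.load_state_dict(torch.load('best_model_TimeAwareGPT2_n_embd_120_n_layer_12_n_head_12.pt',
#                                    map_location='cpu'))
#   model.eval()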