From 5b0642eb6e6187b0e3c33b9c0343efba2808c3ba Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Wed, 22 Oct 2025 11:54:48 +0800
Subject: [PATCH] Add train_ddp.py: DistributedDataParallel multi-GPU training
 with distributed samplers, rank-0 checkpointing, and aggregated metrics

---
 train_ddp.py | 364 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 364 insertions(+)
 create mode 100644 train_ddp.py

diff --git a/train_ddp.py b/train_ddp.py
new file mode 100644
index 0000000..752a718
--- /dev/null
+++ b/train_ddp.py
@@ -0,0 +1,364 @@
+import os
+import json
+import math
+import argparse
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+from torch.optim import AdamW
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, DistributedSampler
+
+import numpy as np
+import tqdm
+import matplotlib.pyplot as plt
+
+from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
+from utils import PatientEventDataset
+
+
+class TrainConfig:
+    # Data parameters
+    train_data_path = 'ukb_real_train.bin'
+    val_data_path = 'ukb_real_val.bin'
+    block_length = 48
+
+    # Model parameters
+    n_embd = 120
+    n_layer = 12
+    n_head = 12
+    pdrop = 0.1
+    token_pdrop = 0.1
+    model_name = 'TimeAwareGPT2'
+
+    # Training parameters
+    max_epoch = 200
+    batch_size = 128
+    lr_initial = 6e-4
+    lr_final = 6e-5
+    weight_decay = 2e-1
+    warmup_epochs = 10
+    early_stopping_patience = 10
+    betas = (0.9, 0.99)
+
+    # Loss parameters (ignored tokens)
+    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
+
+
+def setup_ddp(backend: str | None = None):
+    """Initialize torch.distributed from environment variables set by torchrun."""
+    if backend is None:
+        if torch.cuda.is_available() and os.name != 'nt':
+            backend = 'nccl'
+        else:
+            backend = 'gloo'
+    dist.init_process_group(backend=backend)
+
+    local_rank = int(os.environ.get('LOCAL_RANK', 0))
+    rank = int(os.environ.get('RANK', 0))
+    world_size = int(os.environ.get('WORLD_SIZE', 1))
+
+    if torch.cuda.is_available():
+        torch.cuda.set_device(local_rank)
+        device = torch.device(f'cuda:{local_rank}')
+    else:
+        device = torch.device('cpu')
+
+    return rank, world_size, local_rank, device
+
+
+def cleanup_ddp():
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
+def cosine_lr(epoch: int, cfg: TrainConfig) -> float:
+    if epoch < cfg.warmup_epochs:
+        return cfg.lr_initial
+    progress = (epoch - cfg.warmup_epochs) / max(1, (cfg.max_epoch - cfg.warmup_epochs))
+    return cfg.lr_final + 0.5 * (cfg.lr_initial - cfg.lr_final) * (1 + math.cos(math.pi * progress))
+
+
+def allreduce_avg(value: torch.Tensor, world_size: int) -> torch.Tensor:
+    """All-reduce sum then divide by world_size."""
+    value = value.clone().to(torch.float64)
+    dist.all_reduce(value, op=dist.ReduceOp.SUM)
+    value /= world_size
+    return value.to(torch.float32)
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model (DDP). Use torchrun to launch.')
+    parser.add_argument('--n_layer', type=int, default=12)
+    parser.add_argument('--n_embd', type=int, default=120)
+    parser.add_argument('--n_head', type=int, default=12)
+    parser.add_argument('--max_epoch', type=int, default=200)
+    parser.add_argument('--batch_size', type=int, default=128)
+    parser.add_argument('--lr_initial', type=float, default=6e-4)
+    parser.add_argument('--lr_final', type=float, default=6e-5)
+    parser.add_argument('--weight_decay', type=float, default=2e-1)
+    parser.add_argument('--warmup_epochs', type=int, default=10)
+    parser.add_argument('--early_stopping_patience', type=int, default=10)
+    parser.add_argument('--pdrop', type=float, default=0.1)
+    parser.add_argument('--token_pdrop', type=float, default=0.1)
+    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99])
+    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2')
+    parser.add_argument('--backend', type=str, default=None, help='DDP backend (nccl/gloo). Default auto-selects.')
+
+    args = parser.parse_args()
+
+    rank, world_size, local_rank, device = setup_ddp(args.backend)
+
+    # Build config
+    cfg = TrainConfig()
+    cfg.n_layer = args.n_layer
+    cfg.n_embd = args.n_embd
+    cfg.n_head = args.n_head
+    cfg.max_epoch = args.max_epoch
+    cfg.batch_size = args.batch_size
+    cfg.lr_initial = args.lr_initial
+    cfg.lr_final = args.lr_final
+    cfg.weight_decay = args.weight_decay
+    cfg.warmup_epochs = args.warmup_epochs
+    cfg.early_stopping_patience = args.early_stopping_patience
+    cfg.pdrop = args.pdrop
+    cfg.token_pdrop = args.token_pdrop
+    cfg.betas = tuple(args.betas)
+    cfg.model_name = args.model
+
+    # Filenames (shared across ranks)
+    model_suffix = f"{cfg.model_name}_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}"
+    model_filename = f"best_model_{model_suffix}.pt"
+    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
+    config_filename = f"config_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}.json"
+
+    # Save config only on rank 0 (class defaults merged with CLI overrides)
+    if rank == 0:
+        with open(config_filename, 'w') as f:
+            json.dump({k: v for k, v in {**vars(TrainConfig), **vars(cfg)}.items() if not k.startswith('__')}, f, indent=4)
+        print(f"[rank 0] Configuration saved to {config_filename}")
+
+    # Load data (all ranks)
+    if rank == 0:
+        print(f"Loading data from {cfg.train_data_path} and {cfg.val_data_path}...")
+    train_data_arr = np.memmap(cfg.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+    val_data_arr = np.memmap(cfg.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+
+    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
+    if rank == 0:
+        print(f"Inferred vocabulary size: {vocab_size}")
+
+    train_dataset = PatientEventDataset(train_data_arr, cfg.block_length)
+    val_dataset = PatientEventDataset(val_data_arr, cfg.block_length)
+
+    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
+    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
+
+    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
+    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=val_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
+
+    # Model, loss, optimizer
+    model_cls = {
+        'TimeAwareGPT2': TimeAwareGPT2,
+        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
+    }[cfg.model_name]
+
+    model = model_cls(
+        vocab_size=vocab_size,
+        n_embd=cfg.n_embd,
+        n_layer=cfg.n_layer,
+        n_head=cfg.n_head,
+        pdrop=cfg.pdrop,
+        token_pdrop=cfg.token_pdrop,
+    ).to(device)
+
+    ddp_model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None, output_device=local_rank if torch.cuda.is_available() else None)
+
+    loss_fn = CombinedLoss(cfg.ignored_token_ids)
+    optimizer = AdamW(ddp_model.parameters(), lr=cfg.lr_initial, weight_decay=cfg.weight_decay, betas=cfg.betas)
+
+    best_val_loss = float('inf')
+    patience_counter = 0
+
+    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+    val_losses_ce, val_losses_surv, val_losses_total = [], [], []
+
+    if rank == 0:
+        print("Starting DDP training...")
+
+    for epoch in range(cfg.max_epoch):
+        # Update sampler epoch for shuffling
+        train_sampler.set_epoch(epoch)
+        val_sampler.set_epoch(epoch)
+
+        # Set LR
+        lr = cosine_lr(epoch, cfg)
+        for pg in optimizer.param_groups:
+            pg['lr'] = lr
+
+        # Train
+        ddp_model.train()
+        train_loss_ce_acc = torch.zeros(1, device=device)
+        train_loss_surv_acc = torch.zeros(1, device=device)
+        train_steps = 0
+
+        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Train]", disable=(rank != 0))
+        for event_seq, time_seq in pbar:
+            event_seq = event_seq.to(device, non_blocking=True)
+            time_seq = time_seq.to(device, non_blocking=True)
+
+            input_events = event_seq[:, :-1]
+            input_times = time_seq[:, :-1]
+            target_events = event_seq[:, 1:]
+            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+            logits = ddp_model(input_events, input_times)
+            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+            loss = loss_ce + loss_survival
+
+            optimizer.zero_grad(set_to_none=True)
+            loss.backward()
+            optimizer.step()
+
+            train_loss_ce_acc += loss_ce.detach()
+            train_loss_surv_acc += loss_survival.detach()
+            train_steps += 1
+
+            if rank == 0:
+                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
+
+        # Aggregate train losses across ranks
+        if train_steps == 0:
+            train_steps = 1
+        steps_tensor = torch.tensor([train_steps], device=device, dtype=torch.float64)
+        dist.all_reduce(steps_tensor, op=dist.ReduceOp.SUM)
+        train_loss_ce_mean = allreduce_avg(train_loss_ce_acc, world_size) / (steps_tensor.item() / world_size)
+        train_loss_surv_mean = allreduce_avg(train_loss_surv_acc, world_size) / (steps_tensor.item() / world_size)
+
+        if rank == 0:
+            train_losses_ce.append(train_loss_ce_mean.item())
+            train_losses_surv.append(train_loss_surv_mean.item())
+            train_losses_total.append(train_loss_ce_mean.item() + train_loss_surv_mean.item())
+
+        # Validation
+        ddp_model.eval()
+        val_loss_ce_acc = torch.zeros(1, device=device)
+        val_loss_surv_acc = torch.zeros(1, device=device)
+        val_steps = 0
+
+        with torch.no_grad():
+            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Val]", disable=(rank != 0))
+            for event_seq, time_seq in pbar_val:
+                event_seq = event_seq.to(device, non_blocking=True)
+                time_seq = time_seq.to(device, non_blocking=True)
+
+                input_events = event_seq[:, :-1]
+                input_times = time_seq[:, :-1]
+                target_events = event_seq[:, 1:]
+                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+                logits = ddp_model(input_events, input_times)
+                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+
+                val_loss_ce_acc += loss_ce.detach()
+                val_loss_surv_acc += loss_survival.detach()
+                val_steps += 1
+
+        if val_steps == 0:
+            val_steps = 1
+        vsteps_tensor = torch.tensor([val_steps], device=device, dtype=torch.float64)
+        dist.all_reduce(vsteps_tensor, op=dist.ReduceOp.SUM)
+        val_loss_ce_mean = allreduce_avg(val_loss_ce_acc, world_size) / (vsteps_tensor.item() / world_size)
+        val_loss_surv_mean = allreduce_avg(val_loss_surv_acc, world_size) / (vsteps_tensor.item() / world_size)
+        total_val_loss = (val_loss_ce_mean + val_loss_surv_mean).item()
+
+        if rank == 0:
+            val_losses_ce.append(val_loss_ce_mean.item())
+            val_losses_surv.append(val_loss_surv_mean.item())
+            val_losses_total.append(total_val_loss)
+
+            print(
+                f"Epoch {epoch+1} Summary:\n"
+                f"  Train Loss: {train_losses_total[-1]:.4f} (CE: {train_losses_ce[-1]:.4f}, Surv: {train_losses_surv[-1]:.4f})\n"
+                f"  Val Loss:   {total_val_loss:.4f} (CE: {val_losses_ce[-1]:.4f}, Surv: {val_losses_surv[-1]:.4f})\n"
+                f"  Learning Rate: {lr:.6f}"
+            )
+
+            # Early stopping on rank 0; broadcast decision
+            improved = total_val_loss < best_val_loss
+            if improved:
+                best_val_loss = total_val_loss
+                patience_counter = 0
+                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
+                torch.save(ddp_model.module.state_dict(), checkpoint_filename)
+            else:
+                if epoch >= cfg.warmup_epochs:
+                    patience_counter += 1
+                print(f"Validation loss did not improve. Patience: {patience_counter}/{cfg.early_stopping_patience}")
+
+            stop_flag = torch.tensor([1.0 if patience_counter >= cfg.early_stopping_patience else 0.0], device=device)
+        else:
+            stop_flag = torch.zeros(1, device=device)
+
+        # Broadcast stop flag to all ranks (float32 on every rank so dtypes match)
+        dist.broadcast(stop_flag, src=0)
+        if stop_flag.item() > 0:
+            if rank == 0:
+                print("\nEarly stopping triggered due to no improvement in validation loss.")
+            break
+
+    # Save best model at the end (rank 0)
+    if rank == 0 and best_val_loss != float('inf'):
+        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
+        state = torch.load(checkpoint_filename, map_location='cpu')
+        ddp_model.module.load_state_dict(state)
+        print(f"Saving final best model to {model_filename}")
+        torch.save(ddp_model.module.state_dict(), model_filename)
+
+        # Save losses to file
+        losses_filename = f"losses_{model_suffix}.txt"
+        with open(losses_filename, 'w') as f:
+            f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
+            for i in range(len(train_losses_total)):
+                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
+        print(f"\nLosses saved to {losses_filename}")
+
+        # Plot curves
+        num_epochs = len(train_losses_total)
+        epochs = range(1, num_epochs + 1)
+        plt.figure(figsize=(18, 5))
+
+        plt.subplot(1, 3, 1)
+        plt.plot(epochs, train_losses_ce, label='Train CE')
+        plt.plot(epochs, val_losses_ce, label='Val CE')
+        plt.title('Cross-Entropy Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend(); plt.grid(True)
+
+        plt.subplot(1, 3, 2)
+        plt.plot(epochs, train_losses_surv, label='Train Survival')
+        plt.plot(epochs, val_losses_surv, label='Val Survival')
+        plt.title('Survival Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend(); plt.grid(True)
+
+        plt.subplot(1, 3, 3)
+        plt.plot(epochs, train_losses_total, label='Train Total')
+        plt.plot(epochs, val_losses_total, label='Val Total')
+        plt.title('Total Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend(); plt.grid(True)
+
+        plt.tight_layout()
+        plt.savefig('loss_curves.png')
+        print("\nLoss curves saved to loss_curves.png")
+
+    cleanup_ddp()
+
+
+if __name__ == '__main__':
+    main()
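
Usage note: train_ddp.py reads the RANK, LOCAL_RANK, and WORLD_SIZE environment variables that torchrun sets, so it should be launched with torchrun rather than plain python. A minimal single-node example, assuming 4 GPUs on the machine (adjust --nproc_per_node and the hyperparameter flags to your setup; all flags shown are defined in the script's argparser):

    torchrun --standalone --nproc_per_node=4 train_ddp.py --model TimeAwareGPT2 --n_layer 12 --n_embd 120 --batch_size 128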