# train.py (DDP-ready)
import os
import math
import argparse

import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    """Hyper-parameters for a training run, read as class attributes."""

    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # DataLoader worker processes per rank (0 = load in the main process).
    num_workers = 4

    # System parameters (main() overwrites this per process based on local_rank)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


def setup_distributed(backend: str = "nccl"):
    """Initialise torch.distributed when launched by torchrun with WORLD_SIZE > 1.

    Args:
        backend: process-group backend passed to ``dist.init_process_group``.

    Returns:
        Tuple ``(is_distributed, world_size, rank, local_rank)``. When not
        distributed, rank and local_rank are both 0.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    is_distributed = world_size > 1
    if not is_distributed:
        return is_distributed, world_size, 0, 0

    if not dist.is_initialized():
        dist.init_process_group(backend=backend, init_method="env://")
    rank = dist.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # Only bind a GPU when one exists; on CPU-only nodes (e.g. gloo backend)
    # torch.cuda.set_device would raise.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    return is_distributed, world_size, rank, local_rank


def cleanup_distributed():
    """Destroy the default process group if one was initialised."""
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def all_reduce_mean(value: float, device, world_size: int):
    """Average a Python float across all ranks.

    Args:
        value: this process's scalar (e.g. a local mean loss).
        device: device on which to build the communication tensor.
        world_size: number of participating processes.

    Returns:
        The arithmetic mean of ``value`` over all ranks, as a Python float.
        Requires an initialised process group.
    """
    tensor = torch.tensor([value], dtype=torch.float32, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= world_size
    return float(tensor.item())


def _lr_for_epoch(epoch: int, config) -> float:
    """Return the LR for ``epoch``: constant lr_initial during warm-up, then cosine decay to lr_final."""
    if epoch < config.warmup_epochs:
        return config.lr_initial
    # Defensive: never divide by zero if max_epoch <= warmup_epochs.
    decay_epochs = max(config.max_epoch - config.warmup_epochs, 1)
    progress = (epoch - config.warmup_epochs) / decay_epochs
    return config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))


def _build_model(vocab_size: int, config):
    """Construct a TimeAwareGPT2 using the architecture hyper-parameters in ``config``."""
    return TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop,
    )


def _build_loader(dataset, config, sampler, shuffle: bool, pin_memory: bool):
    """Create a DataLoader using the batch size and worker count from ``config``."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
        drop_last=False,
        # persistent_workers is only valid when worker processes exist.
        persistent_workers=config.num_workers > 0,
    )


def _batch_losses(model, loss_fn, batch, device):
    """Forward one ``(event_seq, time_seq)`` batch and return ``(loss_ce, loss_survival)``.

    Positions ``[:-1]`` are the inputs; targets are the next event (positions
    ``[1:]``) and the time difference between consecutive time steps.
    """
    event_seq, time_seq = batch
    event_seq = event_seq.to(device, non_blocking=True)
    time_seq = time_seq.to(device, non_blocking=True)

    input_events = event_seq[:, :-1]
    input_times = time_seq[:, :-1]
    target_events = event_seq[:, 1:]
    target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

    logits = model(input_events, input_times)
    return loss_fn(logits, target_events, target_wait_times)


def _global_mean(total: float, steps: int, device, world_size: int, is_dist: bool) -> float:
    """Average ``total`` over ``steps`` on this rank, then across ranks when distributed."""
    local_mean = total / max(steps, 1)
    if is_dist:
        return all_reduce_mean(local_mean, device, world_size)
    return local_mean


def _plot_loss_curves(train_losses_ce, train_losses_surv, train_losses_total,
                      val_losses_ce, val_losses_surv, val_losses_total):
    """Save a 3-panel (CE / survival / total) train-vs-val loss plot to loss_curves.png."""
    num_epochs = len(train_losses_total)
    if num_epochs == 0:
        return
    epochs = range(1, num_epochs + 1)
    panels = [
        ('Cross-Entropy Loss', train_losses_ce, val_losses_ce, 'Train CE', 'Val CE'),
        ('Survival Loss', train_losses_surv, val_losses_surv, 'Train Survival', 'Val Survival'),
        ('Total Loss', train_losses_total, val_losses_total, 'Train Total', 'Val Total'),
    ]

    plt.figure(figsize=(18, 5))
    for idx, (title, train_vals, val_vals, train_label, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.plot(epochs, train_vals, label=train_label)
        plt.plot(epochs, val_vals, label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig('loss_curves.png')
    plt.close()
    print("\nLoss curves saved to loss_curves.png")


def _train(args, is_dist: bool, world_size: int, rank: int, local_rank: int):
    """Run data loading, the train/validate/early-stopping loop, final save and plotting."""
    # Per-rank seeds so dropout and other randomness differ between ranks;
    # DDP synchronises the initial parameters from rank 0 when wrapping.
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.backends.cudnn.benchmark = True

    config = TrainConfig()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    config.device = device
    is_main = (rank == 0)

    # --- 1. Data Loading ---
    if is_main:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label in column 2, + 1).
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Distributed samplers shard each dataset across ranks.
    if is_dist:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    # Pinned host memory only helps (and only avoids a warning) with CUDA.
    pin_memory = device.type == "cuda"
    train_loader = _build_loader(train_dataset, config, train_sampler,
                                 shuffle=(train_sampler is None), pin_memory=pin_memory)
    val_loader = _build_loader(val_dataset, config, val_sampler,
                               shuffle=False, pin_memory=pin_memory)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main:
        print(f"Initializing model on {config.device}...")
    model = _build_model(vocab_size, config).to(device)
    if is_main and hasattr(model, "get_num_params"):
        print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # DDP wrapping
    if is_dist:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank] if device.type == "cuda" else None,
            output_device=local_rank if device.type == "cuda" else None,
            find_unused_parameters=False,
        )

    # --- 3. Training Loop ---
    best_val_loss = float('inf')  # only meaningful on the main process
    patience_counter = 0

    # Loss histories are only collected (and plotted) on the main process.
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main:
        print("Starting training...")

    stop_training = False
    for epoch in range(config.max_epoch):
        if is_dist:
            # Give the sampler the epoch so each epoch gets a different shuffle.
            train_sampler.set_epoch(epoch)

        # --- Learning Rate Scheduling ---
        lr = _lr_for_epoch(epoch, config)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        # Only the main process shows a progress bar.
        if is_main:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader

        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0
        for batch in pbar:
            loss_ce, loss_survival = _batch_losses(model, loss_fn, batch, device)
            loss = loss_ce + loss_survival

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            ce_value = loss_ce.item()
            surv_value = loss_survival.item()
            train_loss_ce_acc += ce_value
            train_loss_surv_acc += surv_value
            train_steps += 1

            if is_main:
                pbar.set_postfix({'loss_ce': f'{ce_value:.4f}', 'loss_surv': f'{surv_value:.4f}', 'lr': f'{lr:.2e}'})

        # Per-rank mean, then averaged over ranks.
        avg_train_loss_ce = _global_mean(train_loss_ce_acc, train_steps, device, world_size, is_dist)
        avg_train_loss_surv = _global_mean(train_loss_surv_acc, train_steps, device, world_size, is_dist)

        if is_main:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)

        # --- Validation Phase ---
        if is_main:
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
        else:
            pbar_val = val_loader

        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0
        with torch.no_grad():
            for batch in pbar_val:
                loss_ce, loss_survival = _batch_losses(model, loss_fn, batch, device)

                ce_value = loss_ce.item()
                surv_value = loss_survival.item()
                val_loss_ce_acc += ce_value
                val_loss_surv_acc += surv_value
                val_steps += 1

                if is_main:
                    pbar_val.set_postfix({'loss_ce': f'{ce_value:.4f}', 'loss_surv': f'{surv_value:.4f}'})

        avg_val_loss_ce = _global_mean(val_loss_ce_acc, val_steps, device, world_size, is_dist)
        avg_val_loss_surv = _global_mean(val_loss_surv_acc, val_steps, device, world_size, is_dist)
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        # Main process prints and records.
        if is_main:
            print(f"Epoch {epoch+1} Summary: \n"
                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f"  Val Loss:   {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f"  Learning Rate: {lr:.6f}")
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)

        # --- Early Stopping Check (based on the rank-averaged total_val_loss) ---
        if is_main:
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                # Unwrap DDP so checkpoint keys carry no "module." prefix.
                raw_model = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model
                torch.save(raw_model.state_dict(), 'best_model_checkpoint.pt')
            else:
                # Patience only accumulates once warm-up is over.
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
            stop_training = patience_counter >= config.early_stopping_patience

        # Broadcast the stop decision from rank 0 so every rank exits together.
        if is_dist:
            flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32)
            dist.broadcast(flag_tensor, src=0)
            stop_training = bool(flag_tensor.item())

        if stop_training:
            if is_main:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # --- Save Best Model at the End (main process only) ---
    if not is_main:
        return

    if best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        # Rebuild a plain single-device model on CPU so best_model.pt loads without DDP.
        model_single = _build_model(vocab_size, config).to('cpu')
        model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
        print("Saving final best model to best_model.pt")
        torch.save(model_single.state_dict(), 'best_model.pt')
    else:
        print("\nTraining finished. No best model to save as validation loss never improved.")

    # --- Plot and Save Loss Curves ---
    _plot_loss_curves(train_losses_ce, train_losses_surv, train_losses_total,
                      val_losses_ce, val_losses_surv, val_losses_total)


def main():
    """Parse CLI flags, set up optional DDP, train, and always tear down the process group."""
    parser = argparse.ArgumentParser()
    # The NCCL backend requires CUDA; default to gloo on CPU-only machines.
    parser.add_argument("--backend", type=str,
                        default="nccl" if torch.cuda.is_available() else "gloo",
                        choices=["nccl", "gloo", "mpi"])
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Distributed initialisation (no-op unless WORLD_SIZE > 1).
    is_dist, world_size, rank, local_rank = setup_distributed(args.backend)
    try:
        _train(args, is_dist, world_size, rank, local_rank)
    finally:
        # Release the process group even if training raised.
        cleanup_distributed()


if __name__ == '__main__':
    main()