From e3e533c9ece60e1c91f8dcb916a3ffd75bf92265 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Thu, 16 Oct 2025 16:58:30 +0800
Subject: [PATCH] update

---
 train_dpp.py | 440 ++++++++++++++++++++++++++++-----------------------
 1 file changed, 245 insertions(+), 195 deletions(-)

diff --git a/train_dpp.py b/train_dpp.py
index eb729fd..bb43e2f 100644
--- a/train_dpp.py
+++ b/train_dpp.py
@@ -1,20 +1,21 @@
-import torch
-import torch.nn as nn
-from torch.optim import Adam
-from torch.utils.data import DataLoader, DistributedSampler
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-import numpy as np
+# train_dpp.py (DDP-ready)
+import os
 import math
+import argparse
+import numpy as np
 import tqdm
 import matplotlib.pyplot as plt
-import os
-import time
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.optim import Adam
+from torch.utils.data import DataLoader, DistributedSampler
 
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
+
 
 # --- Configuration ---
 class TrainConfig:
     # Data parameters
@@ -31,90 +32,120 @@ class TrainConfig:
 
     # Training parameters
     max_epoch = 200
-    batch_size = 512  # increase the total batch size
+    batch_size = 128
     lr_initial = 6e-4
     lr_final = 6e-5
    warmup_epochs = 10
     early_stopping_patience = 5
-
+
     # Loss parameters
+    # 0 = padding, 1 = "no event"
     ignored_token_ids = [0, 1]
 
-    # Distributed training parameters
-    world_size = torch.cuda.device_count()
-    distributed = world_size > 1
+    # System parameters (device is set per local_rank inside main())
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-# --- Main Training Function ---
-def train_worker(local_rank, config):
-    # Initialize distributed training
-    if config.distributed:
-        dist.init_process_group(
-            backend='nccl',
-            init_method='env://',
-            rank=local_rank,
-            world_size=config.world_size
-        )
+
+def setup_distributed(backend: str = "nccl"):
+    """
+    Initialize the process group when launched by torchrun with WORLD_SIZE > 1.
+    Returns (is_distributed, world_size, rank, local_rank).
+    """
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    is_distributed = world_size > 1
+    if is_distributed:
+        if not dist.is_initialized():
+            dist.init_process_group(backend=backend, init_method="env://")
+        rank = dist.get_rank()
+        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
         torch.cuda.set_device(local_rank)
-        device = torch.device('cuda', local_rank)
-        print(f"Worker {local_rank} initialized on device {device}")
     else:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        rank = 0
         local_rank = 0
+    return is_distributed, world_size, rank, local_rank
+
+
+def cleanup_distributed():
+    if dist.is_available() and dist.is_initialized():
+        dist.destroy_process_group()
+
+
+def all_reduce_mean(value: float, device, world_size: int):
+    """
+    value is this process's local float (sum or mean); returns the value averaged over all processes.
+    """
+    tensor = torch.tensor([value], dtype=torch.float32, device=device)
+    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor /= world_size
+    return float(tensor.item())
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
+    parser.add_argument("--seed", type=int, default=42)
+    args = parser.parse_args()
+
+    # Distributed initialization
+    is_dist, world_size, rank, local_rank = setup_distributed(args.backend)
+
+    # Basic environment: per-rank seeding and cuDNN autotuning
+    torch.manual_seed(args.seed + rank)
+    np.random.seed(args.seed + rank)
+    torch.backends.cudnn.benchmark = True
+
+    config = TrainConfig()
+    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+    config.device = device
+
+    is_main = (rank == 0)
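+    # Rank 0 alone handles console logging, progress bars, checkpointing, and plotting.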
 
     # --- 1. Data Loading ---
-    if local_rank == 0:
+    if is_main:
         print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
-
     train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
     val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
 
+    # Infer vocab_size from the data (max label + 1)
     vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
-
-    if local_rank == 0:
+    if is_main:
         print(f"Inferred vocabulary size: {vocab_size}")
-        print(f"Using {config.world_size} GPU(s) for training")
 
     train_dataset = PatientEventDataset(train_data_arr, config.block_length)
     val_dataset = PatientEventDataset(val_data_arr, config.block_length)
 
-    # Compute the per-GPU batch size
-    per_gpu_batch_size = config.batch_size // config.world_size
-
-    # Optimize the data-loader settings
-    if config.distributed:
-        train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
-        val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
+    # Distributed samplers
+    if is_dist:
+        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
+        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
     else:
         train_sampler = None
         val_sampler = None
 
-    # Increase num_workers and use persistent_workers to cut process-creation overhead
     train_loader = DataLoader(
-        train_dataset,
-        batch_size=per_gpu_batch_size,  # per-GPU batch size
-        sampler=train_sampler,
+        train_dataset,
+        batch_size=config.batch_size,
         shuffle=(train_sampler is None),
-        num_workers=8,  # more workers
+        sampler=train_sampler,
+        num_workers=4,
         pin_memory=True,
-        persistent_workers=True,  # keep worker processes alive
-        prefetch_factor=2  # prefetch batches
+        drop_last=False,
+        persistent_workers=True,
     )
-
     val_loader = DataLoader(
-        val_dataset,
-        batch_size=per_gpu_batch_size,
-        sampler=val_sampler,
+        val_dataset,
+        batch_size=config.batch_size,
        shuffle=False,
-        num_workers=8,
+        sampler=val_sampler,
+        num_workers=4,
         pin_memory=True,
-        persistent_workers=True,
-        prefetch_factor=2
+        drop_last=False,
+        persistent_workers=True,
     )
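+    # Note: each rank draws config.batch_size samples per step, so the effective
+    # global batch size under DDP is config.batch_size * world_size.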
 
     # --- 2. Model, Optimizer, and Loss Initialization ---
-    if local_rank == 0:
-        print(f"Initializing model on {device}...")
-
+    if is_main:
+        print(f"Initializing model on {config.device}...")
     model = TimeAwareGPT2(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
@@ -124,36 +155,37 @@ def train_worker(local_rank, config):
         token_pdrop=config.token_pdrop
     ).to(device)
 
-    # Use gradient accumulation to emulate a larger batch size and reduce communication frequency
-    if config.distributed:
-        # Use find_unused_parameters=False for speed
-        model = DDP(model, device_ids=[local_rank], output_device=local_rank,
-                    find_unused_parameters=False)
-
-    if local_rank == 0:
-        if config.distributed:
-            num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
-        else:
-            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
-        print(f"Per GPU batch size: {per_gpu_batch_size}")
+    if is_main and hasattr(model, "get_num_params"):
+        print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
     optimizer = Adam(model.parameters(), lr=config.lr_initial)
 
+    # Wrap the model in DistributedDataParallel
+    if is_dist:
+        model = nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[local_rank] if device.type == "cuda" else None,
+            output_device=local_rank if device.type == "cuda" else None,
+            find_unused_parameters=False,
+        )
+
     # --- 3. Training Loop ---
     best_val_loss = float('inf')
     patience_counter = 0
-
-    if local_rank == 0:
-        train_losses_ce, train_losses_surv, train_losses_total = [], [], []
-        val_losses_ce, val_losses_surv, val_losses_total = [], [], []
 
-    if local_rank == 0:
+    # Loss histories (appended to and plotted only on the main process)
+    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+    val_losses_ce, val_losses_surv, val_losses_total = [], [], []
+
+    if is_main:
         print("Starting training...")
-
+
+    stop_training = False
+
     for epoch in range(config.max_epoch):
-        if config.distributed:
+        # Give the DistributedSampler the current epoch so shuffling differs across epochs
+        if is_dist:
             train_sampler.set_epoch(epoch)
 
         # --- Learning Rate Scheduling ---
@@ -162,24 +194,23 @@ def train_worker(local_rank, config):
         else:
             progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
             lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
-
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         # --- Training Phase ---
+        if is_main:
+            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
+        else:
+            pbar = train_loader  # no tqdm on non-main processes
+
         model.train()
         train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
         train_steps = 0
-
-        # Show the progress bar only on rank 0
-        if local_rank == 0:
-            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
-        else:
-            pbar = train_loader
-
-        batch_start_time = time.time()
-        for batch_idx, (event_seq, time_seq) in enumerate(pbar):
-            event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
+
+        for batch in pbar:
+            event_seq, time_seq = batch
+            event_seq = event_seq.to(device, non_blocking=True)
+            time_seq = time_seq.to(device, non_blocking=True)
 
             # Prepare inputs and targets
             input_events = event_seq[:, :-1]
@@ -193,52 +224,49 @@ def train_worker(local_rank, config):
             loss = loss_ce + loss_survival
 
             # Backward pass and optimization
-            optimizer.zero_grad()
+            optimizer.zero_grad(set_to_none=True)
             loss.backward()
-
-            # Gradient synchronization is handled automatically by DDP
             optimizer.step()
 
-            # Record losses asynchronously to avoid a blocking sync
-            train_loss_ce_acc += loss_ce.item()
-            train_loss_surv_acc += loss_survival.item()
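+            # DDP has already synchronized gradients inside loss.backward();
+            # the losses accumulated below are this rank's local values.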
+            train_loss_ce_acc += float(loss_ce.item())
+            train_loss_surv_acc += float(loss_survival.item())
             train_steps += 1
-
-            if local_rank == 0 and batch_idx % 10 == 0:  # refresh every 10 batches
-                batch_time = time.time() - batch_start_time
-                pbar.set_postfix({
-                    'loss_ce': f'{loss_ce.item():.4f}',
-                    'loss_surv': f'{loss_survival.item():.4f}',
-                    'lr': f'{lr:.2e}',
-                    'batch_time': f'{batch_time:.3f}s'
-                })
-                batch_start_time = time.time()
 
-        # Synchronize the losses only once per epoch to reduce communication
-        if config.distributed:
-            # Use all_reduce to aggregate the losses
-            train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
-            train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
-            train_steps_tensor = torch.tensor([train_steps], device=device)
-
-            dist.all_reduce(train_loss_ce_tensor)
-            dist.all_reduce(train_loss_surv_tensor)
-            dist.all_reduce(train_steps_tensor)
-
-            avg_train_loss_ce = (train_loss_ce_tensor.item() / train_steps_tensor.item())
-            avg_train_loss_surv = (train_loss_surv_tensor.item() / train_steps_tensor.item())
+            if is_main and isinstance(pbar, tqdm.tqdm):
+                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
+
+        # Per-process mean
+        avg_train_loss_ce_local = train_loss_ce_acc / max(train_steps, 1)
+        avg_train_loss_surv_local = train_loss_surv_acc / max(train_steps, 1)
+
+        # Average across all processes
+        if is_dist:
+            avg_train_loss_ce = all_reduce_mean(avg_train_loss_ce_local, device, world_size)
+            avg_train_loss_surv = all_reduce_mean(avg_train_loss_surv_local, device, world_size)
         else:
-            avg_train_loss_ce = train_loss_ce_acc / train_steps
-            avg_train_loss_surv = train_loss_surv_acc / train_steps
+            avg_train_loss_ce = avg_train_loss_ce_local
+            avg_train_loss_surv = avg_train_loss_surv_local
+
+        if is_main:
+            train_losses_ce.append(avg_train_loss_ce)
+            train_losses_surv.append(avg_train_loss_surv)
+            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
 
         # --- Validation Phase ---
+        if is_main:
+            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
+        else:
+            pbar_val = val_loader
+
         model.eval()
         val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
         val_steps = 0
 
         with torch.no_grad():
-            for event_seq, time_seq in val_loader:
-                event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
+            for batch in pbar_val:
+                event_seq, time_seq = batch
+                event_seq = event_seq.to(device, non_blocking=True)
+                time_seq = time_seq.to(device, non_blocking=True)
 
                 input_events = event_seq[:, :-1]
                 input_times = time_seq[:, :-1]
@@ -247,104 +275,126 @@ def train_worker(local_rank, config):
                 logits = model(input_events, input_times)
                 loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
-
-                val_loss_ce_acc += loss_ce.item()
-                val_loss_surv_acc += loss_survival.item()
+
+                val_loss_ce_acc += float(loss_ce.item())
+                val_loss_surv_acc += float(loss_survival.item())
                 val_steps += 1
 
-        # Synchronize the validation losses
-        if config.distributed:
-            val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
-            val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
-            val_steps_tensor = torch.tensor([val_steps], device=device)
-
-            dist.all_reduce(val_loss_ce_tensor)
-            dist.all_reduce(val_loss_surv_tensor)
-            dist.all_reduce(val_steps_tensor)
-
-            avg_val_loss_ce = (val_loss_ce_tensor.item() / val_steps_tensor.item())
-            avg_val_loss_surv = (val_loss_surv_tensor.item() / val_steps_tensor.item())
+                if is_main and isinstance(pbar_val, tqdm.tqdm):
+                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
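+
+        # Note: with drop_last=False, DistributedSampler pads the last batches by
+        # repeating samples, so the cross-rank validation average is only approximate.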
+        avg_val_loss_ce_local = val_loss_ce_acc / max(val_steps, 1)
+        avg_val_loss_surv_local = val_loss_surv_acc / max(val_steps, 1)
+
+        if is_dist:
+            avg_val_loss_ce = all_reduce_mean(avg_val_loss_ce_local, device, world_size)
+            avg_val_loss_surv = all_reduce_mean(avg_val_loss_surv_local, device, world_size)
         else:
-            avg_val_loss_ce = val_loss_ce_acc / val_steps
-            avg_val_loss_surv = val_loss_surv_acc / val_steps
+            avg_val_loss_ce = avg_val_loss_ce_local
+            avg_val_loss_surv = avg_val_loss_surv_local
 
         total_val_loss = avg_val_loss_ce + avg_val_loss_surv
-
-        # Print and save only on rank 0
-        if local_rank == 0:
-            train_losses_ce.append(avg_train_loss_ce)
-            train_losses_surv.append(avg_train_loss_surv)
-            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
-            val_losses_ce.append(avg_val_loss_ce)
-            val_losses_surv.append(avg_val_loss_surv)
-            val_losses_total.append(total_val_loss)
 
+        # Print and record on the main process
+        if is_main:
             print(f"Epoch {epoch+1} Summary: \n"
                   f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                   f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                   f"  Learning Rate: {lr:.6f}")
+            val_losses_ce.append(avg_val_loss_ce)
+            val_losses_surv.append(avg_val_loss_surv)
+            val_losses_total.append(total_val_loss)
 
-            # Early stopping check
+        # --- Early Stopping Check (based on the aggregated total_val_loss) ---
+        improved = False
+        if is_main:
             if total_val_loss < best_val_loss:
                 best_val_loss = total_val_loss
                 patience_counter = 0
+                improved = True
                 print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
-                if config.distributed:
-                    torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
-                else:
-                    torch.save(model.state_dict(), 'best_model_checkpoint.pt')
+                # For DDP, save module.state_dict()
+                state_dict = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict()
+                torch.save(state_dict, 'best_model_checkpoint.pt')
             else:
                 if epoch >= config.warmup_epochs:
                     patience_counter += 1
                     print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
+            stop_training = patience_counter >= config.early_stopping_patience
 
-            if patience_counter >= config.early_stopping_patience:
+        # Broadcast the stop flag from rank 0 so every process exits the loop together
+        if is_dist:
+            flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32)
+            dist.broadcast(flag_tensor, src=0)
+            stop_training = bool(int(flag_tensor.item()))
+
+        if stop_training:
+            if is_main:
                 print("\nEarly stopping triggered due to no improvement in validation loss.")
-                if config.distributed:
-                    stop_signal = torch.tensor(1, device=device)
-                    dist.broadcast(stop_signal, 0)
-                break
+            break
+
+    # --- Save Best Model at the End (main process only) ---
+    if is_main:
+        if best_val_loss != float('inf'):
+            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
+            # For convenience, rebuild a single-device model on the main process, load the best weights, and re-save them
+            model_single = TimeAwareGPT2(
+                vocab_size=vocab_size,
+                n_embd=config.n_embd,
+                n_layer=config.n_layer,
+                n_head=config.n_head,
+                pdrop=config.pdrop,
+                token_pdrop=config.token_pdrop
+            ).to('cpu')
+            model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
+            print("Saving final best model to best_model.pt")
+            torch.save(model_single.state_dict(), 'best_model.pt')
         else:
-            # Non-rank-0 processes check for the stop signal
-            if config.distributed:
-                stop_signal = torch.tensor(0, device=device)
-                dist.broadcast(stop_signal, 0)
-                if stop_signal.item() == 1:
-                    break
+            print("\nTraining finished. No best model to save as validation loss never improved.")
 
-    # Clean up and save
-    if local_rank == 0 and best_val_loss != float('inf'):
-        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-        if config.distributed:
-            model.module.load_state_dict(torch.load('best_model_checkpoint.pt'))
-            torch.save(model.module.state_dict(), 'best_model.pt')
-        else:
-            model.load_state_dict(torch.load('best_model_checkpoint.pt'))
-            torch.save(model.state_dict(), 'best_model.pt')
-        print("Final best model saved to best_model.pt")
+    # --- Plot and Save Loss Curves ---
+    num_epochs = len(train_losses_total)
+    if num_epochs > 0:
+        epochs = range(1, num_epochs + 1)
+        plt.figure(figsize=(18, 5))
 
-    if config.distributed:
-        dist.destroy_process_group()
+        # Plot CE Loss
+        plt.subplot(1, 3, 1)
+        plt.plot(epochs, train_losses_ce, label='Train CE')
+        plt.plot(epochs, val_losses_ce, label='Val CE')
+        plt.title('Cross-Entropy Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        # Plot Survival Loss
+        plt.subplot(1, 3, 2)
+        plt.plot(epochs, train_losses_surv, label='Train Survival')
+        plt.plot(epochs, val_losses_surv, label='Val Survival')
+        plt.title('Survival Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        # Plot Total Loss
+        plt.subplot(1, 3, 3)
+        plt.plot(epochs, train_losses_total, label='Train Total')
+        plt.plot(epochs, val_losses_total, label='Val Total')
+        plt.title('Total Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        plt.tight_layout()
+        plt.savefig('loss_curves.png')
+        print("\nLoss curves saved to loss_curves.png")
+
+    # Tear down the distributed process group
+    cleanup_distributed()
 
-def main():
-    config = TrainConfig()
-
-    # Environment-variable tweaks
-    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # reduce synchronization
-    os.environ['NCCL_DEBUG'] = 'WARN'  # quieter NCCL logs
-    os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker'  # pick the right network interface
-
-    if config.distributed:
-        print(f"Starting distributed training with {config.world_size} GPUs")
-        mp.spawn(
-            train_worker,
-            args=(config,),
-            nprocs=config.world_size,
-            join=True
-        )
-    else:
-        print("Starting single GPU training")
-        train_worker(0, config)
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
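
Launch note: with the mp.spawn entry point removed, the script is intended to be started
with torchrun. A typical single-node run (assuming, for example, 4 GPUs) would be

    torchrun --standalone --nproc_per_node=4 train_dpp.py --backend nccl

A plain "python train_dpp.py" still works: setup_distributed() falls back to
single-process mode when WORLD_SIZE is unset or 1.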