diff --git a/train_dpp.py b/train_dpp.py
index 5f18a36..eb729fd 100644
--- a/train_dpp.py
+++ b/train_dpp.py
@@ -1,16 +1,16 @@
-# train_ddp.py
-import os
-import math
-import numpy as np
-import tqdm
-import matplotlib.pyplot as plt
-
 import torch
 import torch.nn as nn
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam
 from torch.utils.data import DataLoader, DistributedSampler
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+import numpy as np
+import math
+import tqdm
+import matplotlib.pyplot as plt
+import os
+import time
 
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
 
@@ -19,8 +19,8 @@ from utils import PatientEventDataset
 class TrainConfig:
     # Data parameters
     train_data_path = 'ukb_real_train.bin'
-    val_data_path = 'ukb_real_val.bin'
-    block_length = 24  # Sequence length
+    val_data_path = 'ukb_real_val.bin'
+    block_length = 24  # Sequence length
 
     # Model parameters
     n_embd = 256
@@ -31,117 +31,90 @@ class TrainConfig:
 
     # Training parameters
     max_epoch = 200
-    batch_size = 128
+    batch_size = 512  # larger global batch size (split across GPUs)
    lr_initial = 6e-4
     lr_final = 6e-5
     warmup_epochs = 10
     early_stopping_patience = 5
-
+
     # Loss parameters
-    # 0 = padding, 1 = "no event"
     ignored_token_ids = [0, 1]
 
-    # System parameters
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    # Distributed training parameters
+    world_size = torch.cuda.device_count()
+    distributed = world_size > 1
 
-def ddp_is_active():
-    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
-
-def ddp_setup():
-    """Initialize process group if launched by torchrun."""
-    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
-        dist.init_process_group(backend=backend, init_method='env://')
-        return True
-    return False
-
-def ddp_cleanup():
-    if dist.is_initialized():
-        dist.barrier()
-        dist.destroy_process_group()
-
-def all_reduce_mean(value: torch.Tensor):
-    """Mean across processes (no-op if not DDP)."""
-    if ddp_is_active():
-        dist.all_reduce(value, op=dist.ReduceOp.SUM)
-        value /= dist.get_world_size()
-    return value
-
-def bcast_bool(stop: bool):
-    if not ddp_is_active():
-        return stop
-    t = torch.tensor(1 if stop else 0, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
-    dist.broadcast(t, src=0)
-    return bool(t.item())
-
-def main():
-    config = TrainConfig()
-
-    # --- DDP setup ---
-    ddp_enabled = ddp_setup()
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    global_rank = int(os.environ.get("RANK", 0)) if ddp_enabled else 0
-    world_size = int(os.environ.get("WORLD_SIZE", 1)) if ddp_enabled else 1
-
-    if torch.cuda.is_available():
+# --- Main Training Function ---
+def train_worker(local_rank, config):
+    # Initialize distributed training
+    if config.distributed:
+        dist.init_process_group(
+            backend='nccl',
+            init_method='env://',
+            rank=local_rank,
+            world_size=config.world_size
+        )
         torch.cuda.set_device(local_rank)
         device = torch.device('cuda', local_rank)
+        print(f"Worker {local_rank} initialized on device {device}")
     else:
-        device = torch.device(config.device)
-
-    # Seed per-rank (different but deterministic)
-    torch.manual_seed(1337 + global_rank)
-    np.random.seed(1337 + global_rank)
-
-    is_main = (global_rank == 0)
-    if is_main:
-        print(f"DDP enabled: {ddp_enabled} | world_size={world_size} | device={device}")
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        local_rank = 0
 
     # --- 1. Data Loading ---
-    if is_main:
+    if local_rank == 0:
         print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
+
     train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
-    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
 
-    # Infer vocab_size from data (max label + 1)
     vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
-    if is_main:
+
+    if local_rank == 0:
         print(f"Inferred vocabulary size: {vocab_size}")
+        print(f"Using {config.world_size} GPU(s) for training")
 
     train_dataset = PatientEventDataset(train_data_arr, config.block_length)
-    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
+    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
 
-    # DDP samplers (fall back to regular shuffle when not DDP)
-    if ddp_enabled:
-        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True, drop_last=False)
-        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=global_rank, shuffle=False, drop_last=False)
-        shuffle_flag = False
+    # Work out the per-GPU batch size from the global batch size
+    per_gpu_batch_size = config.batch_size // config.world_size
+
+    # Data loader setup
+    if config.distributed:
+        train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
+        val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
     else:
-        train_sampler, val_sampler = None, None
-        shuffle_flag = True
+        train_sampler = None
+        val_sampler = None
 
+    # More workers; persistent_workers avoids re-creating worker processes every epoch
     train_loader = DataLoader(
-        train_dataset,
-        batch_size=config.batch_size,
-        shuffle=shuffle_flag,
+        train_dataset,
+        batch_size=per_gpu_batch_size,  # per-GPU batch size
         sampler=train_sampler,
-        num_workers=4,
+        shuffle=(train_sampler is None),
+        num_workers=8,  # more data-loading workers
         pin_memory=True,
-        drop_last=False,
+        persistent_workers=True,  # keep worker processes alive between epochs
+        prefetch_factor=2  # batches prefetched per worker
     )
+
     val_loader = DataLoader(
-        val_dataset,
-        batch_size=config.batch_size,
-        shuffle=False,
+        val_dataset,
+        batch_size=per_gpu_batch_size,
         sampler=val_sampler,
-        num_workers=4,
+        shuffle=False,
+        num_workers=8,
         pin_memory=True,
-        drop_last=False,
+        persistent_workers=True,
+        prefetch_factor=2
     )
 
-    # --- 2. Model, Optimizer, Loss ---
-    if is_main:
+    # --- 2. Model, Optimizer, and Loss Initialization ---
+    if local_rank == 0:
         print(f"Initializing model on {device}...")
+
     model = TimeAwareGPT2(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
@@ -151,167 +124,175 @@ def main():
         token_pdrop=config.token_pdrop
     ).to(device)
 
-    if is_main:
-        # If your model has get_num_params
-        try:
-            print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
-        except Exception:
-            total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-            print(f"Model initialized with {total_params/1e6:.2f}M trainable parameters.")
+    # Wrap the model in DDP so gradients are synchronized across GPUs
+    if config.distributed:
+        # find_unused_parameters=False skips the unused-parameter search on every step
+        model = DDP(model, device_ids=[local_rank], output_device=local_rank,
+                    find_unused_parameters=False)
 
-    loss_fn = CombinedLoss(config.ignored_token_ids)
+    if local_rank == 0:
+        if config.distributed:
+            num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
+        else:
+            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
+        print(f"Per GPU batch size: {per_gpu_batch_size}")
+
+    loss_fn = CombinedLoss(config.ignored_token_ids)
     optimizer = Adam(model.parameters(), lr=config.lr_initial)
 
-    # Wrap with DDP
-    if ddp_enabled:
-        model = DDP(model, device_ids=[local_rank] if device.type == 'cuda' else None,
-                    output_device=local_rank if device.type == 'cuda' else None,
-                    find_unused_parameters=False)
-
-    # AMP
-    use_amp = (device.type == 'cuda')
-    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
-
     # --- 3. Training Loop ---
     best_val_loss = float('inf')
     patience_counter = 0
+
+    if local_rank == 0:
+        train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+        val_losses_ce, val_losses_surv, val_losses_total = [], [], []
 
-    # store losses only on main to plot
-    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
-    val_losses_ce, val_losses_surv, val_losses_total = [], [], []
-
-    if is_main:
+    if local_rank == 0:
         print("Starting training...")
-
+
     for epoch in range(config.max_epoch):
-        # Ensure different shuffles per epoch under DDP
-        if ddp_enabled:
+        if config.distributed:
             train_sampler.set_epoch(epoch)
-            val_sampler.set_epoch(epoch)
 
-        # --- LR scheduling (same as original) ---
+        # --- Learning Rate Scheduling ---
         if epoch < config.warmup_epochs:
             lr = config.lr_initial
         else:
             progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
             lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
-        for pg in optimizer.param_groups:
-            pg['lr'] = lr
+
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
 
-        # --- Training ---
-        if is_main:
+        # --- Training Phase ---
+        model.train()
+        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
+        train_steps = 0
+
+        # Only show the progress bar on rank 0
+        if local_rank == 0:
             pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
         else:
-            pbar = train_loader  # silent on non-main ranks
-
-        model.train()
-        train_ce_sum = torch.tensor(0.0, device=device)
-        train_surv_sum = torch.tensor(0.0, device=device)
-        train_steps = 0
-
-        for batch in pbar:
-            event_seq, time_seq = batch
-            event_seq = event_seq.to(device, non_blocking=True)
-            time_seq = time_seq.to(device, non_blocking=True)
+            pbar = train_loader
+
+        batch_start_time = time.time()
+        for batch_idx, (event_seq, time_seq) in enumerate(pbar):
+            event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
 
             # Prepare inputs and targets
             input_events = event_seq[:, :-1]
             input_times = time_seq[:, :-1]
             target_events = event_seq[:, 1:]
             target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
 
-            optimizer.zero_grad(set_to_none=True)
-            with torch.cuda.amp.autocast(enabled=use_amp):
-                logits = model(input_events, input_times)
-                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
-                loss = loss_ce + loss_survival
+            # Forward pass
+            logits = model(input_events, input_times)
+            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+            loss = loss_ce + loss_survival
 
-            if use_amp:
-                scaler.scale(loss).backward()
-                scaler.step(optimizer)
-                scaler.update()
-            else:
-                loss.backward()
-                optimizer.step()
+            # Backward pass and optimization
+            optimizer.zero_grad()
+            loss.backward()
+
+            # Gradient synchronization is handled automatically by DDP
+            optimizer.step()
 
-            train_ce_sum += loss_ce.detach()
-            train_surv_sum += loss_survival.detach()
+            # Accumulate losses locally; the cross-rank reduction happens once per epoch
+            train_loss_ce_acc += loss_ce.item()
+            train_loss_surv_acc += loss_survival.item()
             train_steps += 1
+
+            if local_rank == 0 and batch_idx % 10 == 0:  # refresh the progress bar every 10 batches
+                batch_time = time.time() - batch_start_time
+                pbar.set_postfix({
+                    'loss_ce': f'{loss_ce.item():.4f}',
+                    'loss_surv': f'{loss_survival.item():.4f}',
+                    'lr': f'{lr:.2e}',
+                    'batch_time': f'{batch_time:.3f}s'
+                })
+                batch_start_time = time.time()
 
-            if is_main and isinstance(pbar, tqdm.tqdm):
-                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
-                                  'loss_surv': f'{loss_survival.item():.4f}',
-                                  'lr': f'{lr:.2e}'})
-
-        # Average train losses across ranks
-        train_ce_avg = train_ce_sum / max(1, train_steps)
-        train_surv_avg = train_surv_sum / max(1, train_steps)
-        train_ce_avg = all_reduce_mean(train_ce_avg)
-        train_surv_avg = all_reduce_mean(train_surv_avg)
-        train_total_avg = train_ce_avg + train_surv_avg
-
-        # --- Validation ---
-        if is_main:
-            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
+        # Synchronize the accumulated losses only once per epoch to keep communication low
+        if config.distributed:
+            # all_reduce sums the values across ranks
+            train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
+            train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
+            train_steps_tensor = torch.tensor([train_steps], device=device)
+
+            dist.all_reduce(train_loss_ce_tensor)
+            dist.all_reduce(train_loss_surv_tensor)
+            dist.all_reduce(train_steps_tensor)
+
+            avg_train_loss_ce = (train_loss_ce_tensor.item() / train_steps_tensor.item())
+            avg_train_loss_surv = (train_loss_surv_tensor.item() / train_steps_tensor.item())
        else:
-            pbar_val = val_loader
+            avg_train_loss_ce = train_loss_ce_acc / train_steps
+            avg_train_loss_surv = train_loss_surv_acc / train_steps
 
+        # --- Validation Phase ---
         model.eval()
-        val_ce_sum = torch.tensor(0.0, device=device)
-        val_surv_sum = torch.tensor(0.0, device=device)
+        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
         val_steps = 0
+
         with torch.no_grad():
-            for batch in pbar_val:
-                event_seq, time_seq = batch
-                event_seq = event_seq.to(device, non_blocking=True)
-                time_seq = time_seq.to(device, non_blocking=True)
+            for event_seq, time_seq in val_loader:
+                event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
 
                 input_events = event_seq[:, :-1]
                 input_times = time_seq[:, :-1]
                 target_events = event_seq[:, 1:]
                 target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
 
-                with torch.cuda.amp.autocast(enabled=use_amp):
-                    logits = model(input_events, input_times)
-                    loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
-
-                val_ce_sum += loss_ce.detach()
-                val_surv_sum += loss_survival.detach()
+                logits = model(input_events, input_times)
+                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+
+                val_loss_ce_acc += loss_ce.item()
+                val_loss_surv_acc += loss_survival.item()
                 val_steps += 1
 
-                if is_main and isinstance(pbar_val, tqdm.tqdm):
-                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
-                                          'loss_surv': f'{loss_survival.item():.4f}'})
+        # Synchronize the validation losses across ranks
+        if config.distributed:
+            val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
+            val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
+            val_steps_tensor = torch.tensor([val_steps], device=device)
+
+            dist.all_reduce(val_loss_ce_tensor)
+            dist.all_reduce(val_loss_surv_tensor)
+            dist.all_reduce(val_steps_tensor)
+
+            avg_val_loss_ce = (val_loss_ce_tensor.item() / val_steps_tensor.item())
+            avg_val_loss_surv = (val_loss_surv_tensor.item() / val_steps_tensor.item())
+        else:
+            avg_val_loss_ce = val_loss_ce_acc / val_steps
+            avg_val_loss_surv = val_loss_surv_acc / val_steps
 
-        # Average val losses across ranks
-        val_ce_avg = val_ce_sum / max(1, val_steps)
-        val_surv_avg = val_surv_sum / max(1, val_steps)
-        val_ce_avg = all_reduce_mean(val_ce_avg)
-        val_surv_avg = all_reduce_mean(val_surv_avg)
-        val_total_avg = val_ce_avg + val_surv_avg
+        total_val_loss = avg_val_loss_ce + avg_val_loss_surv
+
+        # Only rank 0 records, prints, and saves
+        if local_rank == 0:
+            train_losses_ce.append(avg_train_loss_ce)
+            train_losses_surv.append(avg_train_loss_surv)
+            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
+            val_losses_ce.append(avg_val_loss_ce)
+            val_losses_surv.append(avg_val_loss_surv)
+            val_losses_total.append(total_val_loss)
 
-        # --- Logging & Early Stopping (rank 0) ---
-        stop_now = False
-        if is_main:
-            print(f"Epoch {epoch+1} Summary:\n"
-                  f"  Train Loss: {float(train_total_avg):.4f} (CE: {float(train_ce_avg):.4f}, Surv: {float(train_surv_avg):.4f})\n"
-                  f"  Val Loss: {float(val_total_avg):.4f} (CE: {float(val_ce_avg):.4f}, Surv: {float(val_surv_avg):.4f})\n"
+            print(f"Epoch {epoch+1} Summary:\n"
+                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
+                  f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                   f"  Learning Rate: {lr:.6f}")
 
-            # Record for curves
-            train_losses_ce.append(float(train_ce_avg))
-            train_losses_surv.append(float(train_surv_avg))
-            train_losses_total.append(float(train_total_avg))
-            val_losses_ce.append(float(val_ce_avg))
-            val_losses_surv.append(float(val_surv_avg))
-            val_losses_total.append(float(val_total_avg))
-
-            if val_total_avg < best_val_loss:
-                best_val_loss = float(val_total_avg)
+            # Early stopping check
+            if total_val_loss < best_val_loss:
+                best_val_loss = total_val_loss
                 patience_counter = 0
-                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint (rank0)...")
-                # unwrap DDP
-                to_save = model.module if isinstance(model, DDP) else model
-                torch.save(to_save.state_dict(), 'best_model_checkpoint.pt')
+                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
+                if config.distributed:
+                    torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
+                else:
+                    torch.save(model.state_dict(), 'best_model_checkpoint.pt')
             else:
                 if epoch >= config.warmup_epochs:
                     patience_counter += 1
@@ -319,58 +300,51 @@ def main():
 
                 if patience_counter >= config.early_stopping_patience:
                     print("\nEarly stopping triggered due to no improvement in validation loss.")
-                    stop_now = True
-
-        # Broadcast stop flag so all ranks exit together
-        stop_now = bcast_bool(stop_now)
-        if stop_now:
-            break
-
-    # --- Save Best Model at the End (rank 0 only) ---
-    if is_main:
-        if best_val_loss != float('inf'):
-            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-            # unwrap for loading & final save
-            to_load = model.module if isinstance(model, DDP) else model
-            to_load.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
-            print("Saving final best model to best_model.pt")
-            torch.save(to_load.state_dict(), 'best_model.pt')
         else:
-            print("\nTraining finished. No best model to save as validation loss never improved.")
+            # Non-rank-0 processes only learn about early stopping through the
+            # broadcast below; patience_counter stays at 0 on these ranks.
+            pass
+
+        # Every rank must join this broadcast on every epoch. Broadcasting only when
+        # rank 0 decides to stop would leave the collective calls out of step across
+        # ranks and hang the job.
+        if config.distributed:
+            stop_signal = torch.tensor(
+                1 if patience_counter >= config.early_stopping_patience else 0,
+                device=device)
+            dist.broadcast(stop_signal, src=0)
+            if stop_signal.item() == 1:
+                break
+        elif patience_counter >= config.early_stopping_patience:
+            break
 
-    # --- Plot and Save Loss Curves ---
-    num_epochs = len(train_losses_total)
-    if num_epochs > 0:
-        epochs = range(1, num_epochs + 1)
-        plt.figure(figsize=(18, 5))
+    # Cleanup and final save
+    if local_rank == 0 and best_val_loss != float('inf'):
+        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
+        if config.distributed:
+            model.module.load_state_dict(torch.load('best_model_checkpoint.pt'))
+            torch.save(model.module.state_dict(), 'best_model.pt')
+        else:
+            model.load_state_dict(torch.load('best_model_checkpoint.pt'))
+            torch.save(model.state_dict(), 'best_model.pt')
+        print("Final best model saved to best_model.pt")
 
-        plt.subplot(1, 3, 1)
-        plt.plot(epochs, train_losses_ce, label='Train CE')
-        plt.plot(epochs, val_losses_ce, label='Val CE')
-        plt.title('Cross-Entropy Loss')
-        plt.xlabel('Epochs'); plt.ylabel('Loss')
-        plt.legend(); plt.grid(True)
+    if config.distributed:
+        dist.destroy_process_group()
 
-        plt.subplot(1, 3, 2)
-        plt.plot(epochs, train_losses_surv, label='Train Survival')
-        plt.plot(epochs, val_losses_surv, label='Val Survival')
-        plt.title('Survival Loss')
-        plt.xlabel('Epochs'); plt.ylabel('Loss')
-        plt.legend(); plt.grid(True)
-
-        plt.subplot(1, 3, 3)
-        plt.plot(epochs, train_losses_total, label='Train Total')
-        plt.plot(epochs, val_losses_total, label='Val Total')
-        plt.title('Total Loss')
-        plt.xlabel('Epochs'); plt.ylabel('Loss')
-        plt.legend(); plt.grid(True)
-
-        plt.tight_layout()
-        plt.savefig('loss_curves.png')
-        print("\nLoss curves saved to loss_curves.png")
-
-    # Clean up DDP
-    ddp_cleanup()
+def main():
+    config = TrainConfig()
+
+    # Environment variable tuning
+    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # do not force synchronous kernel launches
+    os.environ['NCCL_DEBUG'] = 'WARN'  # keep NCCL logging quiet
+    os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker'  # skip the loopback and docker interfaces
+
+    # The env:// rendezvous used in train_worker needs these; 29500 is the usual
+    # PyTorch default and any free port will do.
+    os.environ.setdefault('MASTER_ADDR', 'localhost')
+    os.environ.setdefault('MASTER_PORT', '29500')
+
+    if config.distributed:
+        print(f"Starting distributed training with {config.world_size} GPUs")
+        mp.spawn(
+            train_worker,
+            args=(config,),
+            nprocs=config.world_size,
+            join=True
+        )
    else:
+        print("Starting single GPU training")
+        train_worker(0, config)
 
 if __name__ == '__main__':
-    main()
+    main()
\ No newline at end of file
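
Launch note (not part of the patch): with this change the script creates its own workers through mp.spawn, so it is started with a plain `python train_dpp.py` rather than torchrun. The snippet below is a minimal, self-contained sketch of the same spawn-plus-env:// rendezvous pattern that train_worker relies on, useful as a smoke test before a full run. It assumes a single node with at least two CUDA devices; the worker name `_smoke_worker` and port 29500 are illustrative choices, not part of the training script.

# ddp_smoke_test.py -- hypothetical helper, assumptions noted above
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _smoke_worker(rank, world_size):
    # Same rendezvous as train_worker: NCCL backend, env:// init
    dist.init_process_group(backend='nccl', init_method='env://',
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    x = torch.ones(1, device='cuda') * (rank + 1)
    dist.all_reduce(x)  # default op is SUM: 1 + 2 + ... + world_size
    if rank == 0:
        print(f"all_reduce result: {x.item()} "
              f"(expected {world_size * (world_size + 1) / 2})")
    dist.destroy_process_group()


if __name__ == '__main__':
    # env:// needs a master address and port before the workers start
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    n_gpus = torch.cuda.device_count()
    if n_gpus < 2:
        print("Smoke test needs at least 2 GPUs; skipping.")
    else:
        mp.spawn(_smoke_worker, args=(n_gpus,), nprocs=n_gpus, join=True)

If the printed sum matches the expected value, the rendezvous and NCCL setup are working and the training script can be expected to initialize its process group the same way.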