import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import os
import time

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 512  # Global batch size across all GPUs
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    ignored_token_ids = [0, 1]

    # Distributed training parameters
    world_size = torch.cuda.device_count()
    distributed = world_size > 1


# --- Main Training Function ---
def train_worker(local_rank, config):
    # Initialize distributed training
    if config.distributed:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            rank=local_rank,
            world_size=config.world_size
        )
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
        print(f"Worker {local_rank} initialized on device {device}")
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        local_rank = 0

    # --- 1. Data Loading ---
    if local_rank == 0:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")

    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if local_rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")
        print(f"Using {config.world_size} GPU(s) for training")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Per-GPU batch size derived from the global batch size
    per_gpu_batch_size = config.batch_size // config.world_size

    # Distributed samplers shard the data across GPUs
    if config.distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # More workers plus persistent_workers reduces per-epoch worker start-up overhead
    train_loader = DataLoader(
        train_dataset,
        batch_size=per_gpu_batch_size,  # per-GPU batch size
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,  # keep worker processes alive between epochs
        prefetch_factor=2  # batches prefetched per worker
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=per_gpu_batch_size,
        sampler=val_sampler,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )
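    # Each PatientEventDataset item is expected to be a pair (event_seq, time_seq)
    # of integer tensors of length block_length; this is an assumption inferred from
    # how batches are unpacked and sliced in the training loop below.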
    # --- 2. Model, Optimizer, and Loss Initialization ---
    if local_rank == 0:
        print(f"Initializing model on {device}...")

    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)

    if config.distributed:
        # find_unused_parameters=False avoids an extra graph traversal per step
        model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    if local_rank == 0:
        if config.distributed:
            num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
        else:
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
        print(f"Per GPU batch size: {per_gpu_batch_size}")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    if local_rank == 0:
        train_losses_ce, train_losses_surv, train_losses_total = [], [], []
        val_losses_ce, val_losses_surv, val_losses_total = [], [], []
        print("Starting training...")

    for epoch in range(config.max_epoch):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        # --- Learning Rate Scheduling ---
        # Constant learning rate during warmup, then cosine decay to lr_final
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        # Only rank 0 shows a progress bar
        if local_rank == 0:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader

        batch_start_time = time.time()
        for batch_idx, (event_seq, time_seq) in enumerate(pbar):
            event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization (DDP handles gradient synchronization)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate losses locally; cross-rank reduction happens once per epoch
            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1

            if local_rank == 0 and batch_idx % 10 == 0:  # update the progress bar every 10 batches
                batch_time = time.time() - batch_start_time
                pbar.set_postfix({
                    'loss_ce': f'{loss_ce.item():.4f}',
                    'loss_surv': f'{loss_survival.item():.4f}',
                    'lr': f'{lr:.2e}',
                    'batch_time': f'{batch_time:.3f}s'
                })
                batch_start_time = time.time()

        # Reduce the training losses across ranks once per epoch to limit communication
        if config.distributed:
            train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
            train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
            train_steps_tensor = torch.tensor([train_steps], device=device)
            dist.all_reduce(train_loss_ce_tensor)
            dist.all_reduce(train_loss_surv_tensor)
            dist.all_reduce(train_steps_tensor)
            avg_train_loss_ce = train_loss_ce_tensor.item() / train_steps_tensor.item()
            avg_train_loss_surv = train_loss_surv_tensor.item() / train_steps_tensor.item()
        else:
            avg_train_loss_ce = train_loss_ce_acc / train_steps
            avg_train_loss_surv = train_loss_surv_acc / train_steps
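        # Gradient accumulation is not used here; a sketch of how the inner loop could
        # simulate a larger effective batch (assuming a hypothetical accum_steps setting):
        #
        #     (loss / accum_steps).backward()
        #     if (batch_idx + 1) % accum_steps == 0:
        #         optimizer.step()
        #         optimizer.zero_grad()
        #
        # With DDP, wrapping the non-final backward passes in model.no_sync() would
        # additionally skip the per-step gradient all-reduce.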
        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0
        with torch.no_grad():
            for event_seq, time_seq in val_loader:
                event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1

        # Reduce the validation losses across ranks
        if config.distributed:
            val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
            val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
            val_steps_tensor = torch.tensor([val_steps], device=device)
            dist.all_reduce(val_loss_ce_tensor)
            dist.all_reduce(val_loss_surv_tensor)
            dist.all_reduce(val_steps_tensor)
            avg_val_loss_ce = val_loss_ce_tensor.item() / val_steps_tensor.item()
            avg_val_loss_surv = val_loss_surv_tensor.item() / val_steps_tensor.item()
        else:
            avg_val_loss_ce = val_loss_ce_acc / val_steps
            avg_val_loss_surv = val_loss_surv_acc / val_steps

        total_val_loss = avg_val_loss_ce + avg_val_loss_surv
        stop_training = False

        # Only rank 0 logs metrics, saves checkpoints, and decides on early stopping
        if local_rank == 0:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)

            print(f"Epoch {epoch+1} Summary: \n"
                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f"  Learning Rate: {lr:.6f}")

            # Early stopping check
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                if config.distributed:
                    torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
                else:
                    torch.save(model.state_dict(), 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
                if patience_counter >= config.early_stopping_patience:
                    print("\nEarly stopping triggered due to no improvement in validation loss.")
                    stop_training = True

        # Broadcast the early-stopping decision from rank 0 every epoch so that all
        # ranks leave the loop in the same iteration (a broadcast issued only when
        # stopping would leave the other ranks blocked)
        if config.distributed:
            stop_tensor = torch.tensor(1 if stop_training else 0, device=device)
            dist.broadcast(stop_tensor, src=0)
            stop_training = bool(stop_tensor.item())
        if stop_training:
            break
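    # Optional: plot the accumulated loss curves on rank 0 (sketch; the
    # 'loss_curves.png' file name is an assumption)
    if local_rank == 0 and train_losses_total:
        epochs = range(1, len(train_losses_total) + 1)
        plt.figure(figsize=(8, 5))
        plt.plot(epochs, train_losses_total, label='train total')
        plt.plot(epochs, val_losses_total, label='val total')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig('loss_curves.png')
        plt.close()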
    # Cleanup and final save
    if local_rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        if config.distributed:
            model.module.load_state_dict(torch.load('best_model_checkpoint.pt', map_location=device))
            torch.save(model.module.state_dict(), 'best_model.pt')
        else:
            model.load_state_dict(torch.load('best_model_checkpoint.pt', map_location=device))
            torch.save(model.state_dict(), 'best_model.pt')
        print("Final best model saved to best_model.pt")

    if config.distributed:
        dist.destroy_process_group()


def main():
    config = TrainConfig()

    # Environment settings
    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # do not force synchronous kernel launches
    os.environ['NCCL_DEBUG'] = 'WARN'  # reduce NCCL log noise
    os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker'  # exclude loopback and docker network interfaces
    # init_method='env://' needs a rendezvous address; these single-node defaults are
    # only applied if the environment does not already provide them
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')

    if config.distributed:
        print(f"Starting distributed training with {config.world_size} GPUs")
        mp.spawn(
            train_worker,
            args=(config,),
            nprocs=config.world_size,
            join=True
        )
    else:
        print("Starting single GPU training")
        train_worker(0, config)


if __name__ == '__main__':
    main()
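# Usage (sketch, single multi-GPU node assumed; the file name train.py is hypothetical):
#   python train.py
# main() spawns one worker per visible GPU via mp.spawn and relies on the
# MASTER_ADDR/MASTER_PORT defaults set above; multi-node runs would need a launcher
# such as torchrun together with an adapted rank/world-size setup.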