2025-10-16 16:58:30 +08:00
parent b5172392cb
commit e3e533c9ec

train.py

@@ -1,20 +1,21 @@
-import torch
-import torch.nn as nn
-from torch.optim import Adam
-from torch.utils.data import DataLoader, DistributedSampler
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
-import numpy as np
+# train.py (DDP-ready)
+import os
 import math
+import argparse
+import numpy as np
 import tqdm
 import matplotlib.pyplot as plt
-import os
-import time
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.optim import Adam
+from torch.utils.data import DataLoader, DistributedSampler
+
 from models import TimeAwareGPT2, CombinedLoss
 from utils import PatientEventDataset
 
 # --- Configuration ---
 class TrainConfig:
     # Data parameters
@@ -31,90 +32,120 @@ class TrainConfig:
     # Training parameters
     max_epoch = 200
-    batch_size = 512  # larger total batch size
+    batch_size = 128
     lr_initial = 6e-4
     lr_final = 6e-5
     warmup_epochs = 10
     early_stopping_patience = 5
 
     # Loss parameters
+    # 0 = padding, 1 = "no event"
     ignored_token_ids = [0, 1]
 
-    # Distributed training parameters
-    world_size = torch.cuda.device_count()
-    distributed = world_size > 1
+    # System parameters (device is set dynamically in main() based on local_rank)
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-# --- Main Training Function ---
-def train_worker(local_rank, config):
-    # Initialize distributed training
-    if config.distributed:
-        dist.init_process_group(
-            backend='nccl',
-            init_method='env://',
-            rank=local_rank,
-            world_size=config.world_size
-        )
+
+def setup_distributed(backend: str = "nccl"):
+    """
+    Initialize distributed training when launched by torchrun with WORLD_SIZE > 1.
+    Returns (is_distributed, world_size, rank, local_rank).
+    """
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    is_distributed = world_size > 1
+    if is_distributed:
+        if not dist.is_initialized():
+            dist.init_process_group(backend=backend, init_method="env://")
+        rank = dist.get_rank()
+        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
         torch.cuda.set_device(local_rank)
-        device = torch.device('cuda', local_rank)
-        print(f"Worker {local_rank} initialized on device {device}")
     else:
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        rank = 0
         local_rank = 0
+    return is_distributed, world_size, rank, local_rank
+
+
+def cleanup_distributed():
+    if dist.is_available() and dist.is_initialized():
+        dist.destroy_process_group()
+
+
+def all_reduce_mean(value: float, device, world_size: int):
+    """
+    value is this process's Python float (a local sum/mean); returns the value averaged over all processes.
+    """
+    tensor = torch.tensor([value], dtype=torch.float32, device=device)
+    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    tensor /= world_size
+    return float(tensor.item())
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
+    parser.add_argument("--seed", type=int, default=42)
+    args = parser.parse_args()
+
+    # Distributed initialization
+    is_dist, world_size, rank, local_rank = setup_distributed(args.backend)
+
+    # Basic environment
+    torch.manual_seed(args.seed + rank)
+    np.random.seed(args.seed + rank)
+    torch.backends.cudnn.benchmark = True
+
+    config = TrainConfig()
+    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+    config.device = device
+    is_main = (rank == 0)
 
     # --- 1. Data Loading ---
-    if local_rank == 0:
+    if is_main:
         print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
     train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
     val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
 
-    # Infer vocab_size from the data (max label + 1)
     vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
-    if local_rank == 0:
+    if is_main:
         print(f"Inferred vocabulary size: {vocab_size}")
-        print(f"Using {config.world_size} GPU(s) for training")
 
     train_dataset = PatientEventDataset(train_data_arr, config.block_length)
     val_dataset = PatientEventDataset(val_data_arr, config.block_length)
 
-    # Compute the per-GPU batch size
-    per_gpu_batch_size = config.batch_size // config.world_size
-
-    # Tune the data-loader parameters
-    if config.distributed:
-        train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
-        val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
+    # Distributed samplers
+    if is_dist:
+        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
+        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
     else:
         train_sampler = None
         val_sampler = None
 
-    # More workers and persistent_workers to cut process-creation overhead
     train_loader = DataLoader(
         train_dataset,
-        batch_size=per_gpu_batch_size,  # per-GPU batch size
-        sampler=train_sampler,
+        batch_size=config.batch_size,
         shuffle=(train_sampler is None),
-        num_workers=8,  # more workers
+        sampler=train_sampler,
+        num_workers=4,
         pin_memory=True,
-        persistent_workers=True,  # keep worker processes alive
-        prefetch_factor=2  # prefetch batches
+        drop_last=False,
+        persistent_workers=True if 4 > 0 else False,
     )
     val_loader = DataLoader(
         val_dataset,
-        batch_size=per_gpu_batch_size,
-        sampler=val_sampler,
+        batch_size=config.batch_size,
         shuffle=False,
-        num_workers=8,
+        sampler=val_sampler,
+        num_workers=4,
         pin_memory=True,
-        persistent_workers=True,
-        prefetch_factor=2
+        drop_last=False,
+        persistent_workers=True if 4 > 0 else False,
     )
 
     # --- 2. Model, Optimizer, and Loss Initialization ---
-    if local_rank == 0:
-        print(f"Initializing model on {device}...")
+    if is_main:
+        print(f"Initializing model on {config.device}...")
     model = TimeAwareGPT2(
         vocab_size=vocab_size,
         n_embd=config.n_embd,
@@ -124,36 +155,37 @@ def train_worker(local_rank, config):
         token_pdrop=config.token_pdrop
     ).to(device)
 
-    # Use gradient accumulation to emulate a larger batch size and cut communication frequency
-    if config.distributed:
-        # find_unused_parameters=False is faster
-        model = DDP(model, device_ids=[local_rank], output_device=local_rank,
-                    find_unused_parameters=False)
-
-    if local_rank == 0:
-        if config.distributed:
-            num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
-        else:
-            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
-        print(f"Per GPU batch size: {per_gpu_batch_size}")
+    if is_main and hasattr(model, "get_num_params"):
+        print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
 
     loss_fn = CombinedLoss(config.ignored_token_ids)
     optimizer = Adam(model.parameters(), lr=config.lr_initial)
 
+    # DDP wrapping
+    if is_dist:
+        model = nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[local_rank] if device.type == "cuda" else None,
+            output_device=local_rank if device.type == "cuda" else None,
+            find_unused_parameters=False,
+        )
+
     # --- 3. Training Loop ---
     best_val_loss = float('inf')
     patience_counter = 0
-    if local_rank == 0:
-        train_losses_ce, train_losses_surv, train_losses_total = [], [], []
-        val_losses_ce, val_losses_surv, val_losses_total = [], [], []
+    # Collect loss histories and plot only on the main process
+    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+    val_losses_ce, val_losses_surv, val_losses_total = [], [], []
 
-    if local_rank == 0:
+    if is_main:
         print("Starting training...")
+    stop_training = False
 
     for epoch in range(config.max_epoch):
-        if config.distributed:
+        # Set the epoch on the distributed sampler so data is reshuffled every epoch
+        if is_dist:
             train_sampler.set_epoch(epoch)
 
         # --- Learning Rate Scheduling ---
@@ -162,24 +194,23 @@ def train_worker(local_rank, config):
         else:
             progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
             lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
 
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         # --- Training Phase ---
+        if is_main:
+            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
+        else:
+            pbar = train_loader  # disable tqdm on non-main processes
+
         model.train()
         train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
         train_steps = 0
 
-        # Show the progress bar only on rank 0
-        if local_rank == 0:
-            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
-        else:
-            pbar = train_loader
-
-        batch_start_time = time.time()
-        for batch_idx, (event_seq, time_seq) in enumerate(pbar):
-            event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
+        for batch in pbar:
+            event_seq, time_seq = batch
+            event_seq = event_seq.to(device, non_blocking=True)
+            time_seq = time_seq.to(device, non_blocking=True)
 
             # Prepare inputs and targets
             input_events = event_seq[:, :-1]
@@ -193,52 +224,49 @@ def train_worker(local_rank, config):
             loss = loss_ce + loss_survival
 
             # Backward pass and optimization
-            optimizer.zero_grad()
+            optimizer.zero_grad(set_to_none=True)
             loss.backward()
-            # Gradient synchronization is handled automatically by DDP
             optimizer.step()
 
-            # Record the losses asynchronously to avoid blocking on a sync
-            train_loss_ce_acc += loss_ce.item()
-            train_loss_surv_acc += loss_survival.item()
+            train_loss_ce_acc += float(loss_ce.item())
+            train_loss_surv_acc += float(loss_survival.item())
             train_steps += 1
 
-            if local_rank == 0 and batch_idx % 10 == 0:  # update once every 10 batches
-                batch_time = time.time() - batch_start_time
-                pbar.set_postfix({
-                    'loss_ce': f'{loss_ce.item():.4f}',
-                    'loss_surv': f'{loss_survival.item():.4f}',
-                    'lr': f'{lr:.2e}',
-                    'batch_time': f'{batch_time:.3f}s'
-                })
-                batch_start_time = time.time()
+            if is_main and isinstance(pbar, tqdm.tqdm):
+                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})
 
-        # Synchronize the losses only once at the end of the epoch to reduce communication
-        if config.distributed:
-            # Aggregate the losses with all_reduce
-            train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
-            train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
-            train_steps_tensor = torch.tensor([train_steps], device=device)
-
-            dist.all_reduce(train_loss_ce_tensor)
-            dist.all_reduce(train_loss_surv_tensor)
-            dist.all_reduce(train_steps_tensor)
-
-            avg_train_loss_ce = (train_loss_ce_tensor.item() / train_steps_tensor.item())
-            avg_train_loss_surv = (train_loss_surv_tensor.item() / train_steps_tensor.item())
+        # Per-process means
+        avg_train_loss_ce_local = train_loss_ce_acc / max(train_steps, 1)
+        avg_train_loss_surv_local = train_loss_surv_acc / max(train_steps, 1)
+
+        # Average over all processes
+        if is_dist:
+            avg_train_loss_ce = all_reduce_mean(avg_train_loss_ce_local, device, world_size)
+            avg_train_loss_surv = all_reduce_mean(avg_train_loss_surv_local, device, world_size)
         else:
-            avg_train_loss_ce = train_loss_ce_acc / train_steps
-            avg_train_loss_surv = train_loss_surv_acc / train_steps
+            avg_train_loss_ce = avg_train_loss_ce_local
+            avg_train_loss_surv = avg_train_loss_surv_local
+
+        if is_main:
+            train_losses_ce.append(avg_train_loss_ce)
+            train_losses_surv.append(avg_train_loss_surv)
+            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
 
         # --- Validation Phase ---
+        if is_main:
+            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
+        else:
+            pbar_val = val_loader
+
         model.eval()
         val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
         val_steps = 0
 
         with torch.no_grad():
-            for event_seq, time_seq in val_loader:
-                event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
+            for batch in pbar_val:
+                event_seq, time_seq = batch
+                event_seq = event_seq.to(device, non_blocking=True)
+                time_seq = time_seq.to(device, non_blocking=True)
 
                 input_events = event_seq[:, :-1]
                 input_times = time_seq[:, :-1]
@@ -248,103 +276,125 @@ def train_worker(local_rank, config):
                 logits = model(input_events, input_times)
                 loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
 
-                val_loss_ce_acc += loss_ce.item()
-                val_loss_surv_acc += loss_survival.item()
+                val_loss_ce_acc += float(loss_ce.item())
+                val_loss_surv_acc += float(loss_survival.item())
                 val_steps += 1
 
-        # Synchronize the validation losses
-        if config.distributed:
-            val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
-            val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
-            val_steps_tensor = torch.tensor([val_steps], device=device)
-
-            dist.all_reduce(val_loss_ce_tensor)
-            dist.all_reduce(val_loss_surv_tensor)
-            dist.all_reduce(val_steps_tensor)
-
-            avg_val_loss_ce = (val_loss_ce_tensor.item() / val_steps_tensor.item())
-            avg_val_loss_surv = (val_loss_surv_tensor.item() / val_steps_tensor.item())
+                if is_main and isinstance(pbar_val, tqdm.tqdm):
+                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})
+
+        avg_val_loss_ce_local = val_loss_ce_acc / max(val_steps, 1)
+        avg_val_loss_surv_local = val_loss_surv_acc / max(val_steps, 1)
+
+        if is_dist:
+            avg_val_loss_ce = all_reduce_mean(avg_val_loss_ce_local, device, world_size)
+            avg_val_loss_surv = all_reduce_mean(avg_val_loss_surv_local, device, world_size)
         else:
-            avg_val_loss_ce = val_loss_ce_acc / val_steps
-            avg_val_loss_surv = val_loss_surv_acc / val_steps
+            avg_val_loss_ce = avg_val_loss_ce_local
+            avg_val_loss_surv = avg_val_loss_surv_local
 
         total_val_loss = avg_val_loss_ce + avg_val_loss_surv
 
-        # Print and save only on rank 0
-        if local_rank == 0:
-            train_losses_ce.append(avg_train_loss_ce)
-            train_losses_surv.append(avg_train_loss_surv)
-            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
+        # Print and record on the main process
+        if is_main:
+            val_losses_ce.append(avg_val_loss_ce)
+            val_losses_surv.append(avg_val_loss_surv)
+            val_losses_total.append(total_val_loss)
 
             print(f"Epoch {epoch+1} Summary: \n"
                   f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                   f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                   f"  Learning Rate: {lr:.6f}")
 
-            val_losses_ce.append(avg_val_loss_ce)
-            val_losses_surv.append(avg_val_loss_surv)
-            val_losses_total.append(total_val_loss)
-
-            # Early stopping check
+        # --- Early Stopping Check (based on the aggregated total_val_loss) ---
+        improved = False
+        if is_main:
             if total_val_loss < best_val_loss:
                 best_val_loss = total_val_loss
                 patience_counter = 0
+                improved = True
                 print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
-                if config.distributed:
-                    torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
-                else:
-                    torch.save(model.state_dict(), 'best_model_checkpoint.pt')
+                # DDP: save module.state_dict()
+                state_dict = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict()
+                torch.save(state_dict, 'best_model_checkpoint.pt')
             else:
                 if epoch >= config.warmup_epochs:
                     patience_counter += 1
                 print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
+            stop_training = patience_counter >= config.early_stopping_patience
 
-            if patience_counter >= config.early_stopping_patience:
+        # Broadcast the improved/stop flags to every process so all ranks exit consistently
+        if is_dist:
+            flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32)
+            dist.broadcast(flag_tensor, src=0)
+            stop_training = bool(int(flag_tensor.item()))
+
+        if stop_training:
+            if is_main:
                 print("\nEarly stopping triggered due to no improvement in validation loss.")
-                if config.distributed:
-                    stop_signal = torch.tensor(1, device=device)
-                    dist.broadcast(stop_signal, 0)
-                break
+            break
+
+    # --- Save Best Model at the End (main process only) ---
+    if is_main:
+        if best_val_loss != float('inf'):
+            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
+            # For convenience, rebuild a single-device model on the main process, load the weights, then save them
+            model_single = TimeAwareGPT2(
+                vocab_size=vocab_size,
+                n_embd=config.n_embd,
+                n_layer=config.n_layer,
+                n_head=config.n_head,
+                pdrop=config.pdrop,
+                token_pdrop=config.token_pdrop
+            ).to('cpu')
+            model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
+            print("Saving final best model to best_model.pt")
+            torch.save(model_single.state_dict(), 'best_model.pt')
         else:
-            # Non-rank-0 processes check the stop signal
-            if config.distributed:
-                stop_signal = torch.tensor(0, device=device)
-                dist.broadcast(stop_signal, 0)
-                if stop_signal.item() == 1:
-                    break
-
-    # Cleanup and save
-    if local_rank == 0 and best_val_loss != float('inf'):
-        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-        if config.distributed:
-            model.module.load_state_dict(torch.load('best_model_checkpoint.pt'))
-            torch.save(model.module.state_dict(), 'best_model.pt')
-        else:
-            model.load_state_dict(torch.load('best_model_checkpoint.pt'))
-            torch.save(model.state_dict(), 'best_model.pt')
-        print("Final best model saved to best_model.pt")
-
-    if config.distributed:
-        dist.destroy_process_group()
-
-
-def main():
-    config = TrainConfig()
-
-    # Environment-variable tweaks
-    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'  # reduce synchronization
-    os.environ['NCCL_DEBUG'] = 'WARN'  # quieter NCCL logging
-    os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker'  # pick the right network interface
-
-    if config.distributed:
-        print(f"Starting distributed training with {config.world_size} GPUs")
-        mp.spawn(
-            train_worker,
-            args=(config,),
-            nprocs=config.world_size,
-            join=True
-        )
-    else:
-        print("Starting single GPU training")
-        train_worker(0, config)
+            print("\nTraining finished. No best model to save as validation loss never improved.")
+
+    # --- Plot and Save Loss Curves ---
+    num_epochs = len(train_losses_total)
+    if num_epochs > 0:
+        epochs = range(1, num_epochs + 1)
+        plt.figure(figsize=(18, 5))
+
+        # Plot CE Loss
+        plt.subplot(1, 3, 1)
+        plt.plot(epochs, train_losses_ce, label='Train CE')
+        plt.plot(epochs, val_losses_ce, label='Val CE')
+        plt.title('Cross-Entropy Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        # Plot Survival Loss
+        plt.subplot(1, 3, 2)
+        plt.plot(epochs, train_losses_surv, label='Train Survival')
+        plt.plot(epochs, val_losses_surv, label='Val Survival')
+        plt.title('Survival Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        # Plot Total Loss
+        plt.subplot(1, 3, 3)
+        plt.plot(epochs, train_losses_total, label='Train Total')
+        plt.plot(epochs, val_losses_total, label='Val Total')
+        plt.title('Total Loss')
+        plt.xlabel('Epochs')
+        plt.ylabel('Loss')
+        plt.legend()
+        plt.grid(True)
+
+        plt.tight_layout()
+        plt.savefig('loss_curves.png')
+        print("\nLoss curves saved to loss_curves.png")
+
+    # Clean up the distributed process group
+    cleanup_distributed()
 
 
 if __name__ == '__main__':
     main()
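
Launch sketch (an assumption for illustration, not part of the commit): the new entry point reads WORLD_SIZE and LOCAL_RANK from the environment and initializes the process group with init_method="env://", so it is meant to be started with torchrun; started without torchrun, WORLD_SIZE defaults to 1 and setup_distributed() skips initialization, falling back to single-process training. Assuming one node with 4 GPUs and the file saved as train.py:

    # DDP: torchrun sets WORLD_SIZE, RANK and LOCAL_RANK for each worker process
    torchrun --standalone --nproc_per_node=4 train.py --backend nccl

    # Single process, no torchrun needed
    python train.py --seed 42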