update ddp

2025-10-16 16:46:33 +08:00
parent 6b0b86d9d0
commit b5172392cb


@@ -1,16 +1,16 @@
# train_ddp.py
import os
import math
import time
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
@@ -31,117 +31,90 @@ class TrainConfig:
    # Training parameters
    max_epoch = 200
    batch_size = 512  # total (global) batch size, split evenly across GPUs
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # Distributed training parameters
    world_size = torch.cuda.device_count()
    distributed = world_size > 1
# --- Main Training Function ---
def train_worker(local_rank, config):
    # Initialize distributed training
    if config.distributed:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            rank=local_rank,
            world_size=config.world_size
        )
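        # Note: the env:// rendezvous reads MASTER_ADDR and MASTER_PORT from the
        # environment. torchrun would set them, but mp.spawn does not, so main()
        # below provides single-node defaults before spawning the workers.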

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
        print(f"Worker {local_rank} initialized on device {device}")
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # --- 1. Data Loading ---
    if local_rank == 0:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")

    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1

    if local_rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")
        print(f"Using {config.world_size} GPU(s) for training")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Compute the per-GPU batch size (guard against a zero device count on CPU-only hosts)
    per_gpu_batch_size = config.batch_size // max(1, config.world_size)
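    # Sanity check on the arithmetic: with the default batch_size of 512 and, for
    # example, 4 visible GPUs, each rank trains on batches of 128 samples, so the
    # effective global batch per optimizer step is still 512.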
    # Data loader tuning: shard the data across ranks when training is distributed
    if config.distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None
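    # Each DistributedSampler shards the dataset so a rank only sees roughly
    # len(dataset) / world_size samples per epoch; train_sampler.set_epoch() is
    # called in the training loop so the shards are reshuffled every epoch.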
    # More workers plus persistent_workers to reduce per-epoch process-creation overhead
    train_loader = DataLoader(
        train_dataset,
        batch_size=per_gpu_batch_size,   # per-GPU batch size
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=8,                   # more loader workers
        pin_memory=True,
        persistent_workers=True,         # keep worker processes alive between epochs
        prefetch_factor=2                # batches prefetched per worker
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=per_gpu_batch_size,
        sampler=val_sampler,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )

    # --- 2. Model, Optimizer, and Loss Initialization ---
    if local_rank == 0:
        print(f"Initializing model on {device}...")

    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
@@ -151,167 +124,175 @@ def main():
        token_pdrop=config.token_pdrop
    ).to(device)

    # Wrap the model with DDP for multi-GPU training
    if config.distributed:
        # find_unused_parameters=False skips the per-step scan for unused parameters
        model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                    find_unused_parameters=False)
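        # With DDP, each backward() all-reduces and averages gradients across the
        # world, so every rank's optimizer.step() applies an identical update; the
        # scalar loss values are only averaged explicitly at the end of each epoch.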

    if local_rank == 0:
        if config.distributed:
            num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
        else:
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
        print(f"Per GPU batch size: {per_gpu_batch_size}")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)
    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    if local_rank == 0:
        train_losses_ce, train_losses_surv, train_losses_total = [], [], []
        val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if local_rank == 0:
        print("Starting training...")

    for epoch in range(config.max_epoch):
        if config.distributed:
            train_sampler.set_epoch(epoch)
        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
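        # Quick check of the schedule: during warmup lr stays at lr_initial (6e-4);
        # at the first post-warmup epoch progress is ~0 and cos(0) = 1, so lr is still
        # ~6e-4, and as progress approaches 1, cos(pi) = -1 and lr decays smoothly to
        # lr_final (6e-5).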
        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        # Show the progress bar only on rank 0
        if local_rank == 0:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader

        batch_start_time = time.time()
        for batch_idx, (event_seq, time_seq) in enumerate(pbar):
            event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

            # Prepare inputs and targets
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            # Gradient synchronization across ranks is handled automatically by DDP
            optimizer.step()

            # Accumulate scalar losses locally; the cross-rank reduction happens once per epoch
            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1

            if local_rank == 0 and batch_idx % 10 == 0:  # refresh the progress bar every 10 batches
                batch_time = time.time() - batch_start_time
                pbar.set_postfix({
                    'loss_ce': f'{loss_ce.item():.4f}',
                    'loss_surv': f'{loss_survival.item():.4f}',
                    'lr': f'{lr:.2e}',
                    'batch_time': f'{batch_time:.3f}s'
                })
                batch_start_time = time.time()
        # Reduce the accumulated losses across ranks once per epoch to limit communication
        if config.distributed:
            train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
            train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
            train_steps_tensor = torch.tensor([train_steps], device=device)

            dist.all_reduce(train_loss_ce_tensor)
            dist.all_reduce(train_loss_surv_tensor)
            dist.all_reduce(train_steps_tensor)

            avg_train_loss_ce = train_loss_ce_tensor.item() / train_steps_tensor.item()
            avg_train_loss_surv = train_loss_surv_tensor.item() / train_steps_tensor.item()
        else:
            avg_train_loss_ce = train_loss_ce_acc / train_steps
            avg_train_loss_surv = train_loss_surv_acc / train_steps
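        # all_reduce sums over ranks, so dividing the summed losses by the summed
        # step counts yields the true global mean even if ranks happened to run
        # slightly different numbers of batches.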
        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            for event_seq, time_seq in val_loader:
                event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1

        # Reduce validation losses across ranks
        if config.distributed:
            val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
            val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
            val_steps_tensor = torch.tensor([val_steps], device=device)

            dist.all_reduce(val_loss_ce_tensor)
            dist.all_reduce(val_loss_surv_tensor)
            dist.all_reduce(val_steps_tensor)

            avg_val_loss_ce = val_loss_ce_tensor.item() / val_steps_tensor.item()
            avg_val_loss_surv = val_loss_surv_tensor.item() / val_steps_tensor.item()
        else:
            avg_val_loss_ce = val_loss_ce_acc / val_steps
            avg_val_loss_surv = val_loss_surv_acc / val_steps

        total_val_loss = avg_val_loss_ce + avg_val_loss_surv
        # Logging, checkpointing, and early-stopping bookkeeping run on rank 0 only
        stop_signal = torch.tensor(0, device=device)
        if local_rank == 0:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)

            print(f"Epoch {epoch+1} Summary: \n"
                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f"  Learning Rate: {lr:.6f}")

            # Early stopping / checkpointing check
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                if config.distributed:
                    torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
                else:
                    torch.save(model.state_dict(), 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
@@ -319,58 +300,51 @@ def main():
                    if patience_counter >= config.early_stopping_patience:
                        print("\nEarly stopping triggered due to no improvement in validation loss.")
                        stop_signal.fill_(1)

        # Every rank must join the broadcast on every epoch, otherwise the collective
        # call would hang; rank 0 decides whether to stop and the other ranks follow.
        if config.distributed:
            dist.broadcast(stop_signal, src=0)
        if stop_signal.item() == 1:
            break
    # Final save and cleanup
    if local_rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        if config.distributed:
            model.module.load_state_dict(torch.load('best_model_checkpoint.pt'))
            torch.save(model.module.state_dict(), 'best_model.pt')
        else:
            model.load_state_dict(torch.load('best_model_checkpoint.pt'))
            torch.save(model.state_dict(), 'best_model.pt')
        print("Final best model saved to best_model.pt")

    if config.distributed:
        dist.destroy_process_group()

def main():
    config = TrainConfig()

    # Environment-variable tuning for multi-GPU runs
    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'          # do not force synchronous CUDA launches
    os.environ['NCCL_DEBUG'] = 'WARN'                 # keep NCCL logging quiet
    os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker'   # exclude loopback and docker interfaces

    # init_method='env://' needs a rendezvous address; set single-node defaults
    # if the environment does not already provide them.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')

    if config.distributed:
        print(f"Starting distributed training with {config.world_size} GPUs")
        mp.spawn(
            train_worker,
            args=(config,),
            nprocs=config.world_size,
            join=True
        )
    else:
        print("Starting single GPU training")
        train_worker(0, config)

if __name__ == '__main__':
    main()
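Launch sketch (assuming a single node): running plain "python train_ddp.py" is enough, since main() detects the visible GPUs via torch.cuda.device_count() and mp.spawn starts one train_worker process per GPU, so no torchrun launcher is needed. Restricting CUDA_VISIBLE_DEVICES limits the number of spawned workers, and with a single visible GPU the script falls back to the non-distributed path.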