update ddp

2025-10-16 16:46:33 +08:00
parent 6b0b86d9d0
commit b5172392cb


@@ -1,16 +1,16 @@
# train_ddp.py
import os
import math
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import os
import time
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
@@ -19,8 +19,8 @@ from utils import PatientEventDataset
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 24 # Sequence length
val_data_path = 'ukb_real_val.bin'
block_length = 24 # Sequence length
# Model parameters
n_embd = 256
@@ -31,117 +31,90 @@ class TrainConfig:
# Training parameters
max_epoch = 200
batch_size = 128
batch_size = 512 # increase the total batch size
lr_initial = 6e-4
lr_final = 6e-5
warmup_epochs = 10
early_stopping_patience = 5
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 1]
# System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Distributed training parameters
world_size = torch.cuda.device_count()
distributed = world_size > 1
def ddp_is_active():
return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
def ddp_setup():
"""Initialize process group if launched by torchrun."""
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend, init_method='env://')
return True
return False
def ddp_cleanup():
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
def all_reduce_mean(value: torch.Tensor):
"""Mean across processes (no-op if not DDP)."""
if ddp_is_active():
dist.all_reduce(value, op=dist.ReduceOp.SUM)
value /= dist.get_world_size()
return value
def bcast_bool(stop: bool):
if not ddp_is_active():
return stop
t = torch.tensor(1 if stop else 0, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
dist.broadcast(t, src=0)
return bool(t.item())
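# Usage sketch for the helpers above (illustrative only, assuming a launch such as
# `torchrun --nproc_per_node=2 train_ddp.py` so that RANK/WORLD_SIZE are present):
#
#   ddp_setup()
#   if torch.cuda.is_available():
#       torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))
#   dev = 'cuda' if torch.cuda.is_available() else 'cpu'
#   rank = dist.get_rank()
#   x = torch.tensor(float(rank), device=dev)   # rank 0 -> 0.0, rank 1 -> 1.0
#   all_reduce_mean(x)                          # every rank now holds 0.5 (sum / world_size)
#   bcast_bool(rank == 0)                       # rank 0's value wins: True on every rank
#   ddp_cleanup()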
def main():
config = TrainConfig()
# --- DDP setup ---
ddp_enabled = ddp_setup()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0)) if ddp_enabled else 0
world_size = int(os.environ.get("WORLD_SIZE", 1)) if ddp_enabled else 1
if torch.cuda.is_available():
# --- Main Training Function ---
def train_worker(local_rank, config):
# Initialize distributed training
if config.distributed:
dist.init_process_group(
backend='nccl',
init_method='env://',
rank=local_rank,
world_size=config.world_size
)
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)
print(f"Worker {local_rank} initialized on device {device}")
else:
device = torch.device(config.device)
# Seed per-rank (different but deterministic)
torch.manual_seed(1337 + global_rank)
np.random.seed(1337 + global_rank)
is_main = (global_rank == 0)
if is_main:
print(f"DDP enabled: {ddp_enabled} | world_size={world_size} | device={device}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
local_rank = 0
# --- 1. Data Loading ---
if is_main:
if local_rank == 0:
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
# Infer vocab_size from data (max label + 1)
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
if is_main:
if local_rank == 0:
print(f"Inferred vocabulary size: {vocab_size}")
print(f"Using {config.world_size} GPU(s) for training")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
# DDP samplers (fall back to regular shuffle when not DDP)
if ddp_enabled:
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True, drop_last=False)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=global_rank, shuffle=False, drop_last=False)
shuffle_flag = False
# Compute the per-GPU batch size
per_gpu_batch_size = config.batch_size // config.world_size
# Tune the data loader parameters
if config.distributed:
train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
else:
train_sampler, val_sampler = None, None
shuffle_flag = True
train_sampler = None
val_sampler = None
# Increase num_workers and use persistent_workers to reduce process-creation overhead
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=shuffle_flag,
batch_size=per_gpu_batch_size, # use the per-GPU batch size
sampler=train_sampler,
num_workers=4,
shuffle=(train_sampler is None),
num_workers=8, # more data-loading workers
pin_memory=True,
drop_last=False,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
sampler=val_sampler,
num_workers=4,
pin_memory=True,
drop_last=False,
persistent_workers=True, # keep worker processes alive between epochs
prefetch_factor=2 # prefetch batches in each worker
)
# --- 2. Model, Optimizer, Loss ---
if is_main:
val_loader = DataLoader(
val_dataset,
batch_size=per_gpu_batch_size,
sampler=val_sampler,
shuffle=False,
num_workers=8,
pin_memory=True,
persistent_workers=True,
prefetch_factor=2
)
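# Sizing sketch (illustrative, assuming world_size = 4 GPUs):
#   per_gpu_batch_size     = 512 // 4 = 128    samples per rank per step
#   effective global batch = 128 * 4  = 512    the same as config.batch_size
# DistributedSampler hands each rank a disjoint ~1/world_size shard, so an epoch
# still covers the full dataset (up to sampler padding) while DDP averages the
# gradients across ranks after each backward pass.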
# --- 2. Model, Optimizer, and Loss Initialization ---
if local_rank == 0:
print(f"Initializing model on {device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
@@ -151,167 +124,175 @@ def main():
token_pdrop=config.token_pdrop
).to(device)
if is_main:
# If your model has get_num_params
try:
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
except Exception:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model initialized with {total_params/1e6:.2f}M trainable parameters.")
# Gradient accumulation could be used to simulate an even larger batch size and reduce communication frequency
if config.distributed:
# find_unused_parameters=False speeds up DDP by skipping the unused-parameter search
model = DDP(model, device_ids=[local_rank], output_device=local_rank,
find_unused_parameters=False)
loss_fn = CombinedLoss(config.ignored_token_ids)
if local_rank == 0:
if config.distributed:
num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
else:
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
print(f"Per GPU batch size: {per_gpu_batch_size}")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = Adam(model.parameters(), lr=config.lr_initial)
# Wrap with DDP
if ddp_enabled:
model = DDP(model, device_ids=[local_rank] if device.type == 'cuda' else None,
output_device=local_rank if device.type == 'cuda' else None,
find_unused_parameters=False)
# AMP
use_amp = (device.type == 'cuda')
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
# --- 3. Training Loop ---
best_val_loss = float('inf')
patience_counter = 0
# store losses only on main to plot
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
val_losses_ce, val_losses_surv, val_losses_total = [], [], []
if local_rank == 0:
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
val_losses_ce, val_losses_surv, val_losses_total = [], [], []
if is_main:
if local_rank == 0:
print("Starting training...")
for epoch in range(config.max_epoch):
# Ensure different shuffles per epoch under DDP
if ddp_enabled:
if config.distributed:
train_sampler.set_epoch(epoch)
val_sampler.set_epoch(epoch)
# --- LR scheduling (same as original) ---
# --- Learning Rate Scheduling ---
if epoch < config.warmup_epochs:
lr = config.lr_initial
else:
progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for pg in optimizer.param_groups:
pg['lr'] = lr
# --- Training ---
if is_main:
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
else:
pbar = train_loader # silent on non-main ranks
for param_group in optimizer.param_groups:
param_group['lr'] = lr
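# Schedule sketch -- worked values for lr_initial=6e-4, lr_final=6e-5,
# warmup_epochs=10, max_epoch=200 (the settings in TrainConfig):
#   epochs 0-9 : lr = 6.0e-4                      (flat warmup)
#   epoch 10   : progress = 0.0 -> lr = 6.0e-4    (cosine decay starts at lr_initial)
#   epoch 105  : progress = 0.5 -> lr = 3.3e-4    (midpoint of the decay)
#   epoch 199  : progress ~ 1.0 -> lr ~ 6.0e-5    (approaches lr_final)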
# --- Training Phase ---
model.train()
train_ce_sum = torch.tensor(0.0, device=device)
train_surv_sum = torch.tensor(0.0, device=device)
train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
train_steps = 0
for batch in pbar:
event_seq, time_seq = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
# Only show the progress bar on rank 0
if local_rank == 0:
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
else:
pbar = train_loader
batch_start_time = time.time()
for batch_idx, (event_seq, time_seq) in enumerate(pbar):
event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
# Prepare inputs and targets
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast(enabled=use_amp):
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
# Forward pass
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
if use_amp:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
train_ce_sum += loss_ce.detach()
train_surv_sum += loss_survival.detach()
# Gradient synchronization is handled automatically by DDP
optimizer.step()
# Accumulate the losses locally; cross-rank reduction is deferred to the end of the epoch
train_loss_ce_acc += loss_ce.item()
train_loss_surv_acc += loss_survival.item()
train_steps += 1
if is_main and isinstance(pbar, tqdm.tqdm):
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
'loss_surv': f'{loss_survival.item():.4f}',
'lr': f'{lr:.2e}'})
if local_rank == 0 and batch_idx % 10 == 0: # update the progress bar every 10 batches
batch_time = time.time() - batch_start_time
pbar.set_postfix({
'loss_ce': f'{loss_ce.item():.4f}',
'loss_surv': f'{loss_survival.item():.4f}',
'lr': f'{lr:.2e}',
'batch_time': f'{batch_time:.3f}s'
})
batch_start_time = time.time()
# Average train losses across ranks
train_ce_avg = train_ce_sum / max(1, train_steps)
train_surv_avg = train_surv_sum / max(1, train_steps)
train_ce_avg = all_reduce_mean(train_ce_avg)
train_surv_avg = all_reduce_mean(train_surv_avg)
train_total_avg = train_ce_avg + train_surv_avg
# Synchronize the losses only once at the end of the epoch to reduce communication
if config.distributed:
# Sum the losses across ranks with all_reduce
train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
train_steps_tensor = torch.tensor([train_steps], device=device)
# --- Validation ---
if is_main:
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
dist.all_reduce(train_loss_ce_tensor)
dist.all_reduce(train_loss_surv_tensor)
dist.all_reduce(train_steps_tensor)
avg_train_loss_ce = (train_loss_ce_tensor.item() / train_steps_tensor.item())
avg_train_loss_surv = (train_loss_surv_tensor.item() / train_steps_tensor.item())
else:
pbar_val = val_loader
avg_train_loss_ce = train_loss_ce_acc / train_steps
avg_train_loss_surv = train_loss_surv_acc / train_steps
# --- Validation Phase ---
model.eval()
val_ce_sum = torch.tensor(0.0, device=device)
val_surv_sum = torch.tensor(0.0, device=device)
val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
val_steps = 0
with torch.no_grad():
for batch in pbar_val:
event_seq, time_seq = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
for event_seq, time_seq in val_loader:
event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
with torch.cuda.amp.autocast(enabled=use_amp):
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_ce_sum += loss_ce.detach()
val_surv_sum += loss_survival.detach()
val_loss_ce_acc += loss_ce.item()
val_loss_surv_acc += loss_survival.item()
val_steps += 1
if is_main and isinstance(pbar_val, tqdm.tqdm):
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
'loss_surv': f'{loss_survival.item():.4f}'})
# Synchronize the validation losses across ranks
if config.distributed:
val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
val_steps_tensor = torch.tensor([val_steps], device=device)
# Average val losses across ranks
val_ce_avg = val_ce_sum / max(1, val_steps)
val_surv_avg = val_surv_sum / max(1, val_steps)
val_ce_avg = all_reduce_mean(val_ce_avg)
val_surv_avg = all_reduce_mean(val_surv_avg)
val_total_avg = val_ce_avg + val_surv_avg
dist.all_reduce(val_loss_ce_tensor)
dist.all_reduce(val_loss_surv_tensor)
dist.all_reduce(val_steps_tensor)
# --- Logging & Early Stopping (rank 0) ---
stop_now = False
if is_main:
print(f"Epoch {epoch+1} Summary:\n"
f" Train Loss: {float(train_total_avg):.4f} (CE: {float(train_ce_avg):.4f}, Surv: {float(train_surv_avg):.4f})\n"
f" Val Loss: {float(val_total_avg):.4f} (CE: {float(val_ce_avg):.4f}, Surv: {float(val_surv_avg):.4f})\n"
avg_val_loss_ce = (val_loss_ce_tensor.item() / val_steps_tensor.item())
avg_val_loss_surv = (val_loss_surv_tensor.item() / val_steps_tensor.item())
else:
avg_val_loss_ce = val_loss_ce_acc / val_steps
avg_val_loss_surv = val_loss_surv_acc / val_steps
total_val_loss = avg_val_loss_ce + avg_val_loss_surv
# Only print and save checkpoints on rank 0
if local_rank == 0:
train_losses_ce.append(avg_train_loss_ce)
train_losses_surv.append(avg_train_loss_surv)
train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
val_losses_ce.append(avg_val_loss_ce)
val_losses_surv.append(avg_val_loss_surv)
val_losses_total.append(total_val_loss)
print(f"Epoch {epoch+1} Summary: \n"
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
f" Learning Rate: {lr:.6f}")
# Record for curves
train_losses_ce.append(float(train_ce_avg))
train_losses_surv.append(float(train_surv_avg))
train_losses_total.append(float(train_total_avg))
val_losses_ce.append(float(val_ce_avg))
val_losses_surv.append(float(val_surv_avg))
val_losses_total.append(float(val_total_avg))
if val_total_avg < best_val_loss:
best_val_loss = float(val_total_avg)
# Early stopping check
if total_val_loss < best_val_loss:
best_val_loss = total_val_loss
patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint (rank0)...")
# unwrap DDP
to_save = model.module if isinstance(model, DDP) else model
torch.save(to_save.state_dict(), 'best_model_checkpoint.pt')
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
if config.distributed:
torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
else:
torch.save(model.state_dict(), 'best_model_checkpoint.pt')
else:
if epoch >= config.warmup_epochs:
patience_counter += 1
@@ -319,58 +300,51 @@ def main():
if patience_counter >= config.early_stopping_patience:
print("\nEarly stopping triggered due to no improvement in validation loss.")
stop_now = True
# Broadcast stop flag so all ranks exit together
stop_now = bcast_bool(stop_now)
if stop_now:
break
# --- Save Best Model at the End (rank 0 only) ---
if is_main:
if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
# unwrap for loading & final save
to_load = model.module if isinstance(model, DDP) else model
to_load.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
print("Saving final best model to best_model.pt")
torch.save(to_load.state_dict(), 'best_model.pt')
if config.distributed:
stop_signal = torch.tensor(1, device=device)
dist.broadcast(stop_signal, 0)
break
else:
print("\nTraining finished. No best model to save as validation loss never improved.")
# Non-rank-0 processes check for the stop signal broadcast from rank 0
if config.distributed:
stop_signal = torch.tensor(0, device=device)
dist.broadcast(stop_signal, 0)
if stop_signal.item() == 1:
break
# --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total)
if num_epochs > 0:
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(18, 5))
# Cleanup and final save
if local_rank == 0 and best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
if config.distributed:
model.module.load_state_dict(torch.load('best_model_checkpoint.pt'))
torch.save(model.module.state_dict(), 'best_model.pt')
else:
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
torch.save(model.state_dict(), 'best_model.pt')
print("Final best model saved to best_model.pt")
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses_ce, label='Train CE')
plt.plot(epochs, val_losses_ce, label='Val CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Epochs'); plt.ylabel('Loss')
plt.legend(); plt.grid(True)
if config.distributed:
dist.destroy_process_group()
plt.subplot(1, 3, 2)
plt.plot(epochs, train_losses_surv, label='Train Survival')
plt.plot(epochs, val_losses_surv, label='Val Survival')
plt.title('Survival Loss')
plt.xlabel('Epochs'); plt.ylabel('Loss')
plt.legend(); plt.grid(True)
def main():
config = TrainConfig()
plt.subplot(1, 3, 3)
plt.plot(epochs, train_losses_total, label='Train Total')
plt.plot(epochs, val_losses_total, label='Val Total')
plt.title('Total Loss')
plt.xlabel('Epochs'); plt.ylabel('Loss')
plt.legend(); plt.grid(True)
# Environment variable tuning
os.environ['CUDA_LAUNCH_BLOCKING'] = '0' # do not force synchronous CUDA launches
os.environ['NCCL_DEBUG'] = 'WARN' # reduce NCCL logging
os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker' # pick the right network interface (exclude lo/docker)
plt.tight_layout()
plt.savefig('loss_curves.png')
print("\nLoss curves saved to loss_curves.png")
# Clean up DDP
ddp_cleanup()
if config.distributed:
print(f"Starting distributed training with {config.world_size} GPUs")
mp.spawn(
train_worker,
args=(config,),
nprocs=config.world_size,
join=True
)
else:
print("Starting single GPU training")
train_worker(0, config)
if __name__ == '__main__':
main()
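
A launch sketch, not taken from the commit: the mp.spawn path initializes the process group with init_method='env://', which reads MASTER_ADDR and MASTER_PORT from the environment even though rank and world_size are passed explicitly, so a single-node run would typically export them before spawning (the file names and port below are placeholders). The torchrun-style path from the previous version needs no explicit spawn, since torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT itself.

# launch_sketch.py -- hypothetical helper, shown only to illustrate the rendezvous
import os
import torch.multiprocessing as mp
from train_ddp import TrainConfig, train_worker

if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', 'localhost')  # rendezvous host for init_method='env://'
    os.environ.setdefault('MASTER_PORT', '29500')      # any free TCP port
    config = TrainConfig()
    if config.distributed:
        mp.spawn(train_worker, args=(config,), nprocs=config.world_size, join=True)
    else:
        train_worker(0, config)

# Equivalent single-node launch of the torchrun-based variant:
#   torchrun --nproc_per_node=4 train_ddp.py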