DeepHealth/train_dpp.py
# train.py (DDP-ready)
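#
# Minimal launch sketch (flags are illustrative, not prescriptive):
#   single process:       python train_dpp.py
#   multi-GPU, one node:  torchrun --standalone --nproc_per_node=4 train_dpp.py --backend nccl
# The script detects a torchrun launch through the WORLD_SIZE / LOCAL_RANK
# environment variables and falls back to single-process training otherwise.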
import os
import math
import argparse
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters (device is set per local_rank inside main())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
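# Data note (inferred from the loading code in main()): the *.bin files are read
# as flat uint32 records reshaped to (N, 3). Only column 2 (the event/token
# label) is used directly here, to infer vocab_size; the meaning of the other
# two columns is left to PatientEventDataset and is not assumed by this script.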
def setup_distributed(backend: str = "nccl"):
    """
    Initialize the process group if the script was launched by torchrun with
    WORLD_SIZE > 1 (init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK
    and WORLD_SIZE from the environment).
    Returns (is_distributed, world_size, rank, local_rank).
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    is_distributed = world_size > 1
    if is_distributed:
        if not dist.is_initialized():
            dist.init_process_group(backend=backend, init_method="env://")
        rank = dist.get_rank()
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
    else:
        rank = 0
        local_rank = 0
    return is_distributed, world_size, rank, local_rank
def cleanup_distributed():
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def all_reduce_mean(value: float, device, world_size: int):
    """
    `value` is a Python float (this process's sum or mean); returns the float
    averaged over all processes.
    """
    tensor = torch.tensor([value], dtype=torch.float32, device=device)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= world_size
    return float(tensor.item())
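# Note: all_reduce_mean() averages each rank's *per-batch mean* loss and weights
# every rank equally. Because DistributedSampler gives each rank the same number
# of batches (padding if needed) while batch sizes can differ slightly with
# drop_last=False, the aggregated value is a close approximation of the true
# global mean rather than an exact token-weighted average.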
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo", "mpi"])
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Distributed initialization
    is_dist, world_size, rank, local_rank = setup_distributed(args.backend)

    # Basic runtime setup
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    torch.backends.cudnn.benchmark = True

    config = TrainConfig()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    config.device = device
    is_main = (rank == 0)
    # --- 1. Data Loading ---
    if is_main:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Distributed samplers
    if is_dist:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    else:
        train_sampler = None
        val_sampler = None

    num_workers = 4
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        persistent_workers=num_workers > 0,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        persistent_workers=num_workers > 0,
    )
    # --- 2. Model, Optimizer, and Loss Initialization ---
    if is_main:
        print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)
    if is_main and hasattr(model, "get_num_params"):
        print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # Wrap the model in DDP
    if is_dist:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank] if device.type == "cuda" else None,
            output_device=local_rank if device.type == "cuda" else None,
            find_unused_parameters=False,
        )
    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    # Loss histories are collected and plotted only on the main process
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main:
        print("Starting training...")

    stop_training = False
    for epoch in range(config.max_epoch):
        # Set the epoch on the distributed sampler so shuffling differs across epochs
        if is_dist:
            train_sampler.set_epoch(epoch)

        # --- Learning Rate Scheduling: constant warmup, then cosine decay to lr_final ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # --- Training Phase ---
        if is_main:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader  # no tqdm on non-main processes

        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        for batch in pbar:
            event_seq, time_seq = batch
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

            # Prepare inputs and targets (next-event prediction with waiting times)
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += float(loss_ce.item())
            train_loss_surv_acc += float(loss_survival.item())
            train_steps += 1

            if is_main and isinstance(pbar, tqdm.tqdm):
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        # Per-process means
        avg_train_loss_ce_local = train_loss_ce_acc / max(train_steps, 1)
        avg_train_loss_surv_local = train_loss_surv_acc / max(train_steps, 1)

        # Average across all processes
        if is_dist:
            avg_train_loss_ce = all_reduce_mean(avg_train_loss_ce_local, device, world_size)
            avg_train_loss_surv = all_reduce_mean(avg_train_loss_surv_local, device, world_size)
        else:
            avg_train_loss_ce = avg_train_loss_ce_local
            avg_train_loss_surv = avg_train_loss_surv_local

        if is_main:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
        # --- Validation Phase ---
        if is_main:
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
        else:
            pbar_val = val_loader

        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            for batch in pbar_val:
                event_seq, time_seq = batch
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += float(loss_ce.item())
                val_loss_surv_acc += float(loss_survival.item())
                val_steps += 1

                if is_main and isinstance(pbar_val, tqdm.tqdm):
                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'})

        avg_val_loss_ce_local = val_loss_ce_acc / max(val_steps, 1)
        avg_val_loss_surv_local = val_loss_surv_acc / max(val_steps, 1)

        if is_dist:
            avg_val_loss_ce = all_reduce_mean(avg_val_loss_ce_local, device, world_size)
            avg_val_loss_surv = all_reduce_mean(avg_val_loss_surv_local, device, world_size)
        else:
            avg_val_loss_ce = avg_val_loss_ce_local
            avg_val_loss_surv = avg_val_loss_surv_local

        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        # Print and record the epoch summary on the main process
        if is_main:
            print(f"Epoch {epoch+1} Summary: \n"
                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f"  Learning Rate: {lr:.6f}")
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)
        # --- Early Stopping Check (based on the aggregated total_val_loss) ---
        improved = False
        if is_main:
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                improved = True
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                # DDP: save the underlying module's state_dict
                state_dict = model.module.state_dict() if isinstance(model, nn.parallel.DistributedDataParallel) else model.state_dict()
                torch.save(state_dict, 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
                stop_training = patience_counter >= config.early_stopping_patience

        # Broadcast the stop flag from rank 0 so all processes exit consistently
        if is_dist:
            flag_tensor = torch.tensor([1 if stop_training else 0], device=device, dtype=torch.int32)
            dist.broadcast(flag_tensor, src=0)
            stop_training = bool(int(flag_tensor.item()))

        if stop_training:
            if is_main:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break
    # --- Save Best Model at the End (main process only) ---
    if is_main:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # For convenience, rebuild a single-device model on the main process,
            # load the best checkpoint, and save the weights again.
            model_single = TimeAwareGPT2(
                vocab_size=vocab_size,
                n_embd=config.n_embd,
                n_layer=config.n_layer,
                n_head=config.n_head,
                pdrop=config.pdrop,
                token_pdrop=config.token_pdrop
            ).to('cpu')
            model_single.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
            print("Saving final best model to best_model.pt")
            torch.save(model_single.state_dict(), 'best_model.pt')
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")
        # --- Plot and Save Loss Curves ---
        num_epochs = len(train_losses_total)
        if num_epochs > 0:
            epochs = range(1, num_epochs + 1)
            plt.figure(figsize=(18, 5))

            # Plot CE Loss
            plt.subplot(1, 3, 1)
            plt.plot(epochs, train_losses_ce, label='Train CE')
            plt.plot(epochs, val_losses_ce, label='Val CE')
            plt.title('Cross-Entropy Loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)

            # Plot Survival Loss
            plt.subplot(1, 3, 2)
            plt.plot(epochs, train_losses_surv, label='Train Survival')
            plt.plot(epochs, val_losses_surv, label='Val Survival')
            plt.title('Survival Loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)

            # Plot Total Loss
            plt.subplot(1, 3, 3)
            plt.plot(epochs, train_losses_total, label='Train Total')
            plt.plot(epochs, val_losses_total, label='Val Total')
            plt.title('Total Loss')
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)

            plt.tight_layout()
            plt.savefig('loss_curves.png')
            print("\nLoss curves saved to loss_curves.png")
    # Clean up the distributed process group (runs on all ranks)
    cleanup_distributed()
if __name__ == '__main__':
    main()