# DeepHealth/train_dpp.py

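# Distributed (DDP) training script for the TimeAwareGPT2 event-sequence model.
# Usage (single node): python train_dpp.py
# One worker process is spawned per visible GPU, with a single-GPU/CPU fallback otherwise.
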
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import os
import time
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset

# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 512  # Total (global) batch size, split across all GPUs
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    ignored_token_ids = [0, 1]

    # Distributed training parameters
    world_size = torch.cuda.device_count()
    distributed = world_size > 1

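# Note: batch_size is the global batch; each GPU processes batch_size // world_size samples per
# step, so batch_size should be divisible by world_size or the effective global batch shrinks.
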
# --- Main Training Function ---
def train_worker(local_rank, config):
    # Initialize distributed training
    if config.distributed:
        dist.init_process_group(
            backend='nccl',
            init_method='env://',
            rank=local_rank,
            world_size=config.world_size
        )
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
        print(f"Worker {local_rank} initialized on device {device}")
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        local_rank = 0

    # --- 1. Data Loading ---
    if local_rank == 0:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")

    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if local_rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")
        print(f"Using {config.world_size} GPU(s) for training")

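    # Assumed record layout: each row of the memmapped .bin file is a uint32 triple, with column 2
    # holding the event token id (used above to infer vocab_size); the exact meaning of columns 0
    # and 1 is defined by PatientEventDataset.
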
    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # Compute the per-GPU batch size from the global batch size
    # (guard against a CPU-only run where device_count() is 0)
    per_gpu_batch_size = config.batch_size // max(config.world_size, 1)

    # Use DistributedSampler to shard the datasets across GPUs when running distributed
    if config.distributed:
        train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=config.world_size, rank=local_rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # More workers plus persistent_workers to reduce per-epoch process-creation overhead
    train_loader = DataLoader(
        train_dataset,
        batch_size=per_gpu_batch_size,   # per-GPU batch size
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=8,                   # more loader workers per process
        pin_memory=True,
        persistent_workers=True,         # keep worker processes alive between epochs
        prefetch_factor=2                # prefetch batches per worker
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=per_gpu_batch_size,
        sampler=val_sampler,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )

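    # Each DDP process builds its own DataLoader, so with num_workers=8 the node runs
    # 8 * world_size loader workers in total; lower num_workers if CPU or RAM becomes the bottleneck.
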
    # --- 2. Model, Optimizer, and Loss Initialization ---
    if local_rank == 0:
        print(f"Initializing model on {device}...")

    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(device)

    # Wrap the model with DDP for multi-GPU training;
    # find_unused_parameters=False avoids an extra traversal of the autograd graph each step
    if config.distributed:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank,
                    find_unused_parameters=False)

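    # DDP all-reduces and averages gradients across ranks during loss.backward(), so every process
    # applies the same update in optimizer.step() without any manual gradient synchronization.
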
    if local_rank == 0:
        if config.distributed:
            num_params = sum(p.numel() for p in model.module.parameters() if p.requires_grad)
        else:
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model initialized with {num_params/1e6:.2f}M trainable parameters.")
        print(f"Per GPU batch size: {per_gpu_batch_size}")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    if local_rank == 0:
        train_losses_ce, train_losses_surv, train_losses_total = [], [], []
        val_losses_ce, val_losses_surv, val_losses_total = [], [], []
        print("Starting training...")

    for epoch in range(config.max_epoch):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        # --- Learning Rate Scheduling ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

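        # Note: the "warmup" phase above holds the learning rate constant at lr_initial rather than
        # ramping it up; after warmup_epochs the rate follows a half-cosine decay,
        #   lr = lr_final + 0.5 * (lr_initial - lr_final) * (1 + cos(pi * progress)),
        # decaying toward lr_final by the last epoch.
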
        # --- Training Phase ---
        model.train()
        train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0
        train_steps = 0

        # Only show the progress bar on rank 0
        if local_rank == 0:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader

        batch_start_time = time.time()
        for batch_idx, (event_seq, time_seq) in enumerate(pbar):
            event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

            # Prepare inputs and targets: predict the next event and the waiting time until it occurs
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            # Forward pass
            logits = model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            # Backward pass and optimization (gradient synchronization is handled by DDP)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate per-batch losses for the epoch average
            train_loss_ce_acc += loss_ce.item()
            train_loss_surv_acc += loss_survival.item()
            train_steps += 1

            # Update the progress bar every 10 batches on rank 0
            if local_rank == 0 and batch_idx % 10 == 0:
                batch_time = time.time() - batch_start_time
                pbar.set_postfix({
                    'loss_ce': f'{loss_ce.item():.4f}',
                    'loss_surv': f'{loss_survival.item():.4f}',
                    'lr': f'{lr:.2e}',
                    'batch_time': f'{batch_time:.3f}s'
                })
                batch_start_time = time.time()

        # Synchronize the accumulated losses only once per epoch to keep communication cheap
        if config.distributed:
            # Pool the per-rank loss sums and step counts across GPUs with all_reduce
            train_loss_ce_tensor = torch.tensor([train_loss_ce_acc], device=device)
            train_loss_surv_tensor = torch.tensor([train_loss_surv_acc], device=device)
            train_steps_tensor = torch.tensor([train_steps], device=device)
            dist.all_reduce(train_loss_ce_tensor)
            dist.all_reduce(train_loss_surv_tensor)
            dist.all_reduce(train_steps_tensor)
            avg_train_loss_ce = train_loss_ce_tensor.item() / train_steps_tensor.item()
            avg_train_loss_surv = train_loss_surv_tensor.item() / train_steps_tensor.item()
        else:
            avg_train_loss_ce = train_loss_ce_acc / train_steps
            avg_train_loss_surv = train_loss_surv_acc / train_steps

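        # dist.all_reduce defaults to ReduceOp.SUM, so dividing the pooled loss sum by the pooled
        # step count gives the mean loss over all batches processed on all GPUs this epoch.
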
        # --- Validation Phase ---
        model.eval()
        val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
        val_steps = 0

        with torch.no_grad():
            for event_seq, time_seq in val_loader:
                event_seq, time_seq = event_seq.to(device, non_blocking=True), time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
                val_loss_ce_acc += loss_ce.item()
                val_loss_surv_acc += loss_survival.item()
                val_steps += 1

        # Synchronize validation losses across GPUs
        if config.distributed:
            val_loss_ce_tensor = torch.tensor([val_loss_ce_acc], device=device)
            val_loss_surv_tensor = torch.tensor([val_loss_surv_acc], device=device)
            val_steps_tensor = torch.tensor([val_steps], device=device)
            dist.all_reduce(val_loss_ce_tensor)
            dist.all_reduce(val_loss_surv_tensor)
            dist.all_reduce(val_steps_tensor)
            avg_val_loss_ce = val_loss_ce_tensor.item() / val_steps_tensor.item()
            avg_val_loss_surv = val_loss_surv_tensor.item() / val_steps_tensor.item()
        else:
            avg_val_loss_ce = val_loss_ce_acc / val_steps
            avg_val_loss_surv = val_loss_surv_acc / val_steps

        total_val_loss = avg_val_loss_ce + avg_val_loss_surv

        # Logging, checkpointing, and the early-stopping decision happen on rank 0 only
        stop_signal = torch.tensor(0, device=device)
        if local_rank == 0:
            train_losses_ce.append(avg_train_loss_ce)
            train_losses_surv.append(avg_train_loss_surv)
            train_losses_total.append(avg_train_loss_ce + avg_train_loss_surv)
            val_losses_ce.append(avg_val_loss_ce)
            val_losses_surv.append(avg_val_loss_surv)
            val_losses_total.append(total_val_loss)

            print(f"Epoch {epoch+1} Summary: \n"
                  f"  Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
                  f"  Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
                  f"  Learning Rate: {lr:.6f}")

            # Early stopping check
            if total_val_loss < best_val_loss:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                if config.distributed:
                    torch.save(model.module.state_dict(), 'best_model_checkpoint.pt')
                else:
                    torch.save(model.state_dict(), 'best_model_checkpoint.pt')
            else:
                if epoch >= config.warmup_epochs:
                    patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
                if patience_counter >= config.early_stopping_patience:
                    print("\nEarly stopping triggered due to no improvement in validation loss.")
                    stop_signal.fill_(1)

        # Every rank must reach this broadcast on every epoch so the collective calls line up;
        # rank 0 broadcasts its stop decision and all ranks leave the loop together.
        if config.distributed:
            dist.broadcast(stop_signal, src=0)
        if stop_signal.item() == 1:
            break

    # Cleanup and final save
    if local_rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        if config.distributed:
            model.module.load_state_dict(torch.load('best_model_checkpoint.pt', map_location=device))
            torch.save(model.module.state_dict(), 'best_model.pt')
        else:
            model.load_state_dict(torch.load('best_model_checkpoint.pt', map_location=device))
            torch.save(model.state_dict(), 'best_model.pt')
        print("Final best model saved to best_model.pt")

    if config.distributed:
        dist.destroy_process_group()

def main():
    config = TrainConfig()

    # Environment tweaks
    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'         # avoid forced synchronization after every kernel launch
    os.environ['NCCL_DEBUG'] = 'WARN'                # keep NCCL logging quiet
    os.environ['NCCL_SOCKET_IFNAME'] = '^lo,docker'  # exclude loopback and docker network interfaces

    if config.distributed:
        # init_method='env://' in the workers needs a rendezvous address; the defaults below are
        # assumptions for single-node training and are only used if the variables are not already set.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')

        print(f"Starting distributed training with {config.world_size} GPUs")
        # mp.spawn passes the process index as the first argument, which train_worker uses as local_rank
        mp.spawn(
            train_worker,
            args=(config,),
            nprocs=config.world_size,
            join=True
        )
    else:
        print("Starting single GPU training")
        train_worker(0, config)

if __name__ == '__main__':
    main()