DeepHealth/train_dpp.py
# train_ddp.py
import os
import math
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
# Data parameters
train_data_path = 'ukb_real_train.bin'
val_data_path = 'ukb_real_val.bin'
block_length = 24 # Sequence length
# Model parameters
n_embd = 256
n_layer = 8
n_head = 8
pdrop = 0.1
token_pdrop = 0.1
# Training parameters
max_epoch = 200
batch_size = 128
lr_initial = 6e-4
lr_final = 6e-5
warmup_epochs = 10
early_stopping_patience = 5
# Loss parameters
# 0 = padding, 1 = "no event"
ignored_token_ids = [0, 1]
# System parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def ddp_is_active():
return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
def ddp_setup():
"""Initialize process group if launched by torchrun."""
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend, init_method='env://')
return True
return False
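# Example launch (a sketch, assuming a single node; adjust --nproc_per_node to the
# number of local GPUs):
#   torchrun --standalone --nproc_per_node=4 train_ddp.py
# Running `python train_ddp.py` directly falls back to single-process training.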
def ddp_cleanup():
if dist.is_initialized():
dist.barrier()
dist.destroy_process_group()
def all_reduce_mean(value: torch.Tensor):
"""Mean across processes (no-op if not DDP)."""
if ddp_is_active():
dist.all_reduce(value, op=dist.ReduceOp.SUM)
value /= dist.get_world_size()
return value
def bcast_bool(stop: bool):
if not ddp_is_active():
return stop
t = torch.tensor(1 if stop else 0, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
dist.broadcast(t, src=0)
return bool(t.item())
def main():
config = TrainConfig()
# --- DDP setup ---
ddp_enabled = ddp_setup()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0)) if ddp_enabled else 0
world_size = int(os.environ.get("WORLD_SIZE", 1)) if ddp_enabled else 1
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
device = torch.device('cuda', local_rank)
else:
device = torch.device(config.device)
# Seed per-rank (different but deterministic)
torch.manual_seed(1337 + global_rank)
np.random.seed(1337 + global_rank)
is_main = (global_rank == 0)
if is_main:
print(f"DDP enabled: {ddp_enabled} | world_size={world_size} | device={device}")
# --- 1. Data Loading ---
if is_main:
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
# Infer vocab_size from data (max label + 1)
vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
if is_main:
print(f"Inferred vocabulary size: {vocab_size}")
train_dataset = PatientEventDataset(train_data_arr, config.block_length)
val_dataset = PatientEventDataset(val_data_arr, config.block_length)
# DDP samplers (fall back to regular shuffle when not DDP)
if ddp_enabled:
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True, drop_last=False)
val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=global_rank, shuffle=False, drop_last=False)
shuffle_flag = False
else:
train_sampler, val_sampler = None, None
shuffle_flag = True
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=shuffle_flag,
sampler=train_sampler,
num_workers=4,
pin_memory=True,
drop_last=False,
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
sampler=val_sampler,
num_workers=4,
pin_memory=True,
drop_last=False,
)
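# Note: DistributedSampler partitions the dataset across ranks, so each rank sees
# roughly len(dataset)/world_size samples per epoch and the effective global batch
# size is batch_size * world_size.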
# --- 2. Model, Optimizer, Loss ---
if is_main:
print(f"Initializing model on {device}...")
model = TimeAwareGPT2(
vocab_size=vocab_size,
n_embd=config.n_embd,
n_layer=config.n_layer,
n_head=config.n_head,
pdrop=config.pdrop,
token_pdrop=config.token_pdrop
).to(device)
if is_main:
# Print the trainable-parameter count via get_num_params() if the model provides it, else count directly
try:
print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
except Exception:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model initialized with {total_params/1e6:.2f}M trainable parameters.")
loss_fn = CombinedLoss(config.ignored_token_ids)
optimizer = Adam(model.parameters(), lr=config.lr_initial)
# Wrap with DDP
if ddp_enabled:
model = DDP(model, device_ids=[local_rank] if device.type == 'cuda' else None,
output_device=local_rank if device.type == 'cuda' else None,
find_unused_parameters=False)
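# DDP averages gradients across ranks automatically during backward(), so the
# training loop below stays identical to the single-process version.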
# AMP
use_amp = (device.type == 'cuda')
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
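# Mixed precision is used only on CUDA; with enabled=False, the autocast and
# GradScaler calls below behave as no-ops, so the CPU path is unchanged.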
# --- 3. Training Loop ---
best_val_loss = float('inf')
patience_counter = 0
# Loss histories are recorded on the main rank only and used for plotting at the end
train_losses_ce, train_losses_surv, train_losses_total = [], [], []
val_losses_ce, val_losses_surv, val_losses_total = [], [], []
if is_main:
print("Starting training...")
for epoch in range(config.max_epoch):
# Ensure different shuffles per epoch under DDP
if ddp_enabled:
train_sampler.set_epoch(epoch)
val_sampler.set_epoch(epoch)
# --- LR schedule: hold lr_initial for warmup_epochs, then cosine-decay to lr_final ---
if epoch < config.warmup_epochs:
lr = config.lr_initial
else:
progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
for pg in optimizer.param_groups:
pg['lr'] = lr
# --- Training ---
if is_main:
pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
else:
pbar = train_loader # silent on non-main ranks
model.train()
train_ce_sum = torch.tensor(0.0, device=device)
train_surv_sum = torch.tensor(0.0, device=device)
train_steps = 0
for batch in pbar:
event_seq, time_seq = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
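# Next-event prediction: shift the sequence by one position so the model, given
# events/times up to step t, predicts the event at t+1 and the waiting time until
# it occurs.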
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast(enabled=use_amp):
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
loss = loss_ce + loss_survival
if use_amp:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
train_ce_sum += loss_ce.detach()
train_surv_sum += loss_survival.detach()
train_steps += 1
if is_main and isinstance(pbar, tqdm.tqdm):
pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
'loss_surv': f'{loss_survival.item():.4f}',
'lr': f'{lr:.2e}'})
# Average train losses across ranks
train_ce_avg = train_ce_sum / max(1, train_steps)
train_surv_avg = train_surv_sum / max(1, train_steps)
train_ce_avg = all_reduce_mean(train_ce_avg)
train_surv_avg = all_reduce_mean(train_surv_avg)
train_total_avg = train_ce_avg + train_surv_avg
# --- Validation ---
if is_main:
pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
else:
pbar_val = val_loader
model.eval()
val_ce_sum = torch.tensor(0.0, device=device)
val_surv_sum = torch.tensor(0.0, device=device)
val_steps = 0
with torch.no_grad():
for batch in pbar_val:
event_seq, time_seq = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
input_events = event_seq[:, :-1]
input_times = time_seq[:, :-1]
target_events = event_seq[:, 1:]
target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
with torch.cuda.amp.autocast(enabled=use_amp):
logits = model(input_events, input_times)
loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
val_ce_sum += loss_ce.detach()
val_surv_sum += loss_survival.detach()
val_steps += 1
if is_main and isinstance(pbar_val, tqdm.tqdm):
pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
'loss_surv': f'{loss_survival.item():.4f}'})
# Average val losses across ranks
val_ce_avg = val_ce_sum / max(1, val_steps)
val_surv_avg = val_surv_sum / max(1, val_steps)
val_ce_avg = all_reduce_mean(val_ce_avg)
val_surv_avg = all_reduce_mean(val_surv_avg)
val_total_avg = val_ce_avg + val_surv_avg
# --- Logging & Early Stopping (rank 0) ---
stop_now = False
if is_main:
print(f"Epoch {epoch+1} Summary:\n"
f" Train Loss: {float(train_total_avg):.4f} (CE: {float(train_ce_avg):.4f}, Surv: {float(train_surv_avg):.4f})\n"
f" Val Loss: {float(val_total_avg):.4f} (CE: {float(val_ce_avg):.4f}, Surv: {float(val_surv_avg):.4f})\n"
f" Learning Rate: {lr:.6f}")
# Record for curves
train_losses_ce.append(float(train_ce_avg))
train_losses_surv.append(float(train_surv_avg))
train_losses_total.append(float(train_total_avg))
val_losses_ce.append(float(val_ce_avg))
val_losses_surv.append(float(val_surv_avg))
val_losses_total.append(float(val_total_avg))
if val_total_avg < best_val_loss:
best_val_loss = float(val_total_avg)
patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint (rank0)...")
# unwrap DDP
to_save = model.module if isinstance(model, DDP) else model
torch.save(to_save.state_dict(), 'best_model_checkpoint.pt')
else:
if epoch >= config.warmup_epochs:
patience_counter += 1
print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
if patience_counter >= config.early_stopping_patience:
print("\nEarly stopping triggered due to no improvement in validation loss.")
stop_now = True
# Broadcast stop flag so all ranks exit together
stop_now = bcast_bool(stop_now)
if stop_now:
break
# --- Save Best Model at the End (rank 0 only) ---
if is_main:
if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
# unwrap for loading & final save
to_load = model.module if isinstance(model, DDP) else model
to_load.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
print("Saving final best model to best_model.pt")
torch.save(to_load.state_dict(), 'best_model.pt')
else:
print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total)
if num_epochs > 0:
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(18, 5))
plt.subplot(1, 3, 1)
plt.plot(epochs, train_losses_ce, label='Train CE')
plt.plot(epochs, val_losses_ce, label='Val CE')
plt.title('Cross-Entropy Loss')
plt.xlabel('Epochs'); plt.ylabel('Loss')
plt.legend(); plt.grid(True)
plt.subplot(1, 3, 2)
plt.plot(epochs, train_losses_surv, label='Train Survival')
plt.plot(epochs, val_losses_surv, label='Val Survival')
plt.title('Survival Loss')
plt.xlabel('Epochs'); plt.ylabel('Loss')
plt.legend(); plt.grid(True)
plt.subplot(1, 3, 3)
plt.plot(epochs, train_losses_total, label='Train Total')
plt.plot(epochs, val_losses_total, label='Val Total')
plt.title('Total Loss')
plt.xlabel('Epochs'); plt.ylabel('Loss')
plt.legend(); plt.grid(True)
plt.tight_layout()
plt.savefig('loss_curves.png')
print("\nLoss curves saved to loss_curves.png")
# Clean up DDP
ddp_cleanup()
if __name__ == '__main__':
main()