# train_ddp.py
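#
# Distributed (DDP) training script for the TimeAwareGPT2 event model.
#
# A typical launch on a single node with N GPUs looks something like the
# following (the exact torchrun flags depend on your environment):
#
#   torchrun --standalone --nproc_per_node=N train_ddp.py
#
# Running plain `python train_ddp.py` leaves RANK/WORLD_SIZE unset, so the
# script falls back to ordinary single-process training.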

import os
import math

import numpy as np
import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler

from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset


# --- Configuration ---
class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 24  # Sequence length

    # Model parameters
    n_embd = 256
    n_layer = 8
    n_head = 8
    pdrop = 0.1
    token_pdrop = 0.1

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    warmup_epochs = 10
    early_stopping_patience = 5

    # Loss parameters
    # 0 = padding, 1 = "no event"
    ignored_token_ids = [0, 1]

    # System parameters
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


def ddp_is_active():
    """Return True when torch.distributed is initialized with more than one process."""
    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1


def ddp_setup():
    """Initialize the process group if launched by torchrun (RANK/WORLD_SIZE set)."""
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://')
        return True
    return False


def ddp_cleanup():
    """Synchronize ranks and tear down the process group, if one was created."""
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


def all_reduce_mean(value: torch.Tensor):
    """Mean across processes (no-op if not DDP)."""
    if ddp_is_active():
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        value /= dist.get_world_size()
    return value


def bcast_bool(stop: bool):
    """Broadcast a boolean flag from rank 0 so every rank takes the same branch."""
    if not ddp_is_active():
        return stop
    t = torch.tensor(1 if stop else 0, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    dist.broadcast(t, src=0)
    return bool(t.item())


def main():
    config = TrainConfig()

    # --- DDP setup ---
    ddp_enabled = ddp_setup()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = int(os.environ.get("RANK", 0)) if ddp_enabled else 0
    world_size = int(os.environ.get("WORLD_SIZE", 1)) if ddp_enabled else 1

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device('cuda', local_rank)
    else:
        device = torch.device(config.device)

    # Seed per-rank (different but deterministic)
    torch.manual_seed(1337 + global_rank)
    np.random.seed(1337 + global_rank)

    is_main = (global_rank == 0)
    if is_main:
        print(f"DDP enabled: {ddp_enabled} | world_size={world_size} | device={device}")

    # --- 1. Data Loading ---
    if is_main:
        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from data (max label + 1)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if is_main:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    # DDP samplers (fall back to regular shuffle when not DDP)
    if ddp_enabled:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True, drop_last=False)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=global_rank, shuffle=False, drop_last=False)
        shuffle_flag = False
    else:
        train_sampler, val_sampler = None, None
        shuffle_flag = True

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=shuffle_flag,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )
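
    # Caveat (standard DistributedSampler behaviour): with drop_last=False each
    # epoch is padded with repeated samples so that every rank gets the same
    # number of batches, which makes the cross-rank loss averages computed
    # below very slightly approximate.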

    # --- 2. Model, Optimizer, Loss ---
    if is_main:
        print(f"Initializing model on {device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop,
    ).to(device)

    if is_main:
        # Report the parameter count via get_num_params() if the model exposes it;
        # otherwise count trainable parameters directly.
        try:
            print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
        except Exception:
            total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"Model initialized with {total_params / 1e6:.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = Adam(model.parameters(), lr=config.lr_initial)

    # Wrap with DDP
    if ddp_enabled:
        model = DDP(model,
                    device_ids=[local_rank] if device.type == 'cuda' else None,
                    output_device=local_rank if device.type == 'cuda' else None,
                    find_unused_parameters=False)

    # AMP (mixed precision), enabled only when training on CUDA
    use_amp = (device.type == 'cuda')
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
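
    # Note: newer PyTorch releases deprecate torch.cuda.amp in favour of the
    # torch.amp namespace. If your version provides it, the equivalent calls
    # would look roughly like this (not required here):
    #   scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    #   with torch.amp.autocast('cuda', enabled=use_amp): ...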

    # --- 3. Training Loop ---
    best_val_loss = float('inf')
    patience_counter = 0

    # Loss histories are stored on the main rank only, for plotting at the end
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if is_main:
        print("Starting training...")

    for epoch in range(config.max_epoch):
        # Ensure different shuffles per epoch under DDP
        if ddp_enabled:
            train_sampler.set_epoch(epoch)
            val_sampler.set_epoch(epoch)

        # --- LR scheduling: constant at lr_initial for the warmup epochs, then a
        # cosine decay from lr_initial down to lr_final over the remaining epochs ---
        if epoch < config.warmup_epochs:
            lr = config.lr_initial
        else:
            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # --- Training ---
        if is_main:
            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
        else:
            pbar = train_loader  # silent on non-main ranks

        model.train()
        train_ce_sum = torch.tensor(0.0, device=device)
        train_surv_sum = torch.tensor(0.0, device=device)
        train_steps = 0

        for batch in pbar:
            event_seq, time_seq = batch
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

            # Next-event prediction: inputs are shifted by one position, and the
            # elapsed time to the next event is the target for the survival term.
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
                loss = loss_ce + loss_survival

            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            train_ce_sum += loss_ce.detach()
            train_surv_sum += loss_survival.detach()
            train_steps += 1

            if is_main and isinstance(pbar, tqdm.tqdm):
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                                  'loss_surv': f'{loss_survival.item():.4f}',
                                  'lr': f'{lr:.2e}'})

        # Average train losses across ranks
        train_ce_avg = train_ce_sum / max(1, train_steps)
        train_surv_avg = train_surv_sum / max(1, train_steps)
        train_ce_avg = all_reduce_mean(train_ce_avg)
        train_surv_avg = all_reduce_mean(train_surv_avg)
        train_total_avg = train_ce_avg + train_surv_avg

        # --- Validation ---
        if is_main:
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
        else:
            pbar_val = val_loader

        model.eval()
        val_ce_sum = torch.tensor(0.0, device=device)
        val_surv_sum = torch.tensor(0.0, device=device)
        val_steps = 0
        with torch.no_grad():
            for batch in pbar_val:
                event_seq, time_seq = batch
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                with torch.cuda.amp.autocast(enabled=use_amp):
                    logits = model(input_events, input_times)
                    loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_ce_sum += loss_ce.detach()
                val_surv_sum += loss_survival.detach()
                val_steps += 1

                if is_main and isinstance(pbar_val, tqdm.tqdm):
                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
                                          'loss_surv': f'{loss_survival.item():.4f}'})

        # Average val losses across ranks
        val_ce_avg = val_ce_sum / max(1, val_steps)
        val_surv_avg = val_surv_sum / max(1, val_steps)
        val_ce_avg = all_reduce_mean(val_ce_avg)
        val_surv_avg = all_reduce_mean(val_surv_avg)
        val_total_avg = val_ce_avg + val_surv_avg
# --- Logging & Early Stopping (rank 0) ---
|
|
stop_now = False
|
|
if is_main:
|
|
print(f"Epoch {epoch+1} Summary:\n"
|
|
f" Train Loss: {float(train_total_avg):.4f} (CE: {float(train_ce_avg):.4f}, Surv: {float(train_surv_avg):.4f})\n"
|
|
f" Val Loss: {float(val_total_avg):.4f} (CE: {float(val_ce_avg):.4f}, Surv: {float(val_surv_avg):.4f})\n"
|
|
f" Learning Rate: {lr:.6f}")
|
|
|
|
# Record for curves
|
|
train_losses_ce.append(float(train_ce_avg))
|
|
train_losses_surv.append(float(train_surv_avg))
|
|
train_losses_total.append(float(train_total_avg))
|
|
val_losses_ce.append(float(val_ce_avg))
|
|
val_losses_surv.append(float(val_surv_avg))
|
|
val_losses_total.append(float(val_total_avg))
|
|
|
|
if val_total_avg < best_val_loss:
|
|
best_val_loss = float(val_total_avg)
|
|
patience_counter = 0
|
|
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint (rank0)...")
|
|
# unwrap DDP
|
|
to_save = model.module if isinstance(model, DDP) else model
|
|
torch.save(to_save.state_dict(), 'best_model_checkpoint.pt')
|
|
else:
|
|
if epoch >= config.warmup_epochs:
|
|
patience_counter += 1
|
|
print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
|
|
|
|
if patience_counter >= config.early_stopping_patience:
|
|
print("\nEarly stopping triggered due to no improvement in validation loss.")
|
|
stop_now = True
|
|
|
|
# Broadcast stop flag so all ranks exit together
|
|
stop_now = bcast_bool(stop_now)
|
|
if stop_now:
|
|
break
|
|
|
|

    # --- Save Best Model at the End (rank 0 only) ---
    if is_main:
        if best_val_loss != float('inf'):
            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
            # unwrap for loading & final save
            to_load = model.module if isinstance(model, DDP) else model
            to_load.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
            print("Saving final best model to best_model.pt")
            torch.save(to_load.state_dict(), 'best_model.pt')
        else:
            print("\nTraining finished. No best model to save as validation loss never improved.")

        # --- Plot and Save Loss Curves ---
        num_epochs = len(train_losses_total)
        if num_epochs > 0:
            epochs = range(1, num_epochs + 1)
            plt.figure(figsize=(18, 5))

            plt.subplot(1, 3, 1)
            plt.plot(epochs, train_losses_ce, label='Train CE')
            plt.plot(epochs, val_losses_ce, label='Val CE')
            plt.title('Cross-Entropy Loss')
            plt.xlabel('Epochs'); plt.ylabel('Loss')
            plt.legend(); plt.grid(True)

            plt.subplot(1, 3, 2)
            plt.plot(epochs, train_losses_surv, label='Train Survival')
            plt.plot(epochs, val_losses_surv, label='Val Survival')
            plt.title('Survival Loss')
            plt.xlabel('Epochs'); plt.ylabel('Loss')
            plt.legend(); plt.grid(True)

            plt.subplot(1, 3, 3)
            plt.plot(epochs, train_losses_total, label='Train Total')
            plt.plot(epochs, val_losses_total, label='Val Total')
            plt.title('Total Loss')
            plt.xlabel('Epochs'); plt.ylabel('Loss')
            plt.legend(); plt.grid(True)

            plt.tight_layout()
            plt.savefig('loss_curves.png')
            print("\nLoss curves saved to loss_curves.png")

    # Clean up DDP
    ddp_cleanup()


if __name__ == '__main__':
    main()