# DeepHealth/train_ddp.py
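"""Distributed (DDP) training script for the Time-Aware GPT-2 models.

Launch with torchrun, for example (illustrative; pick --nproc_per_node for
your own GPU count, it is not a project default):

    torchrun --standalone --nproc_per_node=4 train_ddp.py \
        --model TimeAwareGPT2 --n_layer 12 --n_embd 120 --n_head 12
"""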
import os
import json
import math
import argparse

import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

import numpy as np
import tqdm
import matplotlib
matplotlib.use('Agg')  # write plots to file; training nodes are assumed headless
import matplotlib.pyplot as plt

from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from utils import PatientEventDataset


class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1
    model_name = 'TimeAwareGPT2'

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10
    betas = (0.9, 0.99)

    # Loss parameters (token ids excluded from the loss)
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]


def setup_ddp(backend: str | None = None):
    """Initialize torch.distributed from environment variables set by torchrun."""
    if backend is None:
        if torch.cuda.is_available() and os.name != 'nt':
            backend = 'nccl'
        else:
            backend = 'gloo'
    dist.init_process_group(backend=backend)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f'cuda:{local_rank}')
    else:
        device = torch.device('cpu')
    return rank, world_size, local_rank, device


def cleanup_ddp():
    if dist.is_initialized():
        dist.destroy_process_group()


def cosine_lr(epoch: int, cfg: TrainConfig) -> float:
    """Hold lr_initial during warmup, then decay to lr_final on a cosine."""
    if epoch < cfg.warmup_epochs:
        return cfg.lr_initial
    progress = (epoch - cfg.warmup_epochs) / max(1, (cfg.max_epoch - cfg.warmup_epochs))
    return cfg.lr_final + 0.5 * (cfg.lr_initial - cfg.lr_final) * (1 + math.cos(math.pi * progress))
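# Worked example with the defaults (lr_initial=6e-4, lr_final=6e-5,
# warmup_epochs=10, max_epoch=200): epochs 0-9 hold 6e-4, the cosine midpoint
# at epoch 105 gives 3.3e-4, and the schedule reaches ~6e-5 at the last epoch.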


def allreduce_avg(value: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-reduce sum then divide by world_size."""
    value = value.clone().to(torch.float64)
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= world_size
    return value.to(torch.float32)
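# The training/validation loops below divide this per-rank average by
# (total_steps / world_size); algebraically that reduces to
# (global loss sum) / (global step count), i.e. a step-weighted mean over ranks.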


def main():
    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model (DDP). Use torchrun to launch.')
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--n_embd', type=int, default=120)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr_initial', type=float, default=6e-4)
    parser.add_argument('--lr_final', type=float, default=6e-5)
    parser.add_argument('--weight_decay', type=float, default=2e-1)
    parser.add_argument('--warmup_epochs', type=int, default=10)
    parser.add_argument('--early_stopping_patience', type=int, default=10)
    parser.add_argument('--pdrop', type=float, default=0.1)
    parser.add_argument('--token_pdrop', type=float, default=0.1)
    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99])
    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2')
    parser.add_argument('--backend', type=str, default=None, help='DDP backend (nccl/gloo). Default auto-selects.')
    args = parser.parse_args()
    rank, world_size, local_rank, device = setup_ddp(args.backend)

    # Build config
    cfg = TrainConfig()
    cfg.n_layer = args.n_layer
    cfg.n_embd = args.n_embd
    cfg.n_head = args.n_head
    cfg.max_epoch = args.max_epoch
    cfg.batch_size = args.batch_size
    cfg.lr_initial = args.lr_initial
    cfg.lr_final = args.lr_final
    cfg.weight_decay = args.weight_decay
    cfg.warmup_epochs = args.warmup_epochs
    cfg.early_stopping_patience = args.early_stopping_patience
    cfg.pdrop = args.pdrop
    cfg.token_pdrop = args.token_pdrop
    cfg.betas = tuple(args.betas)
    cfg.model_name = args.model

    # Filenames (shared across ranks)
    model_suffix = f"{cfg.model_name}_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}"
    model_filename = f"best_model_{model_suffix}.pt"
    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
    config_filename = f"config_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}.json"
    # Save config only on rank 0. vars(cfg) alone holds just the instance
    # attributes set above, so merge in the class-level defaults (data paths,
    # block_length, ignored_token_ids, ...) that were never overridden.
    if rank == 0:
        cfg_dict = {k: v for k, v in vars(TrainConfig).items()
                    if not k.startswith('__') and not callable(v)}
        cfg_dict.update(vars(cfg))
        with open(config_filename, 'w') as f:
            json.dump(cfg_dict, f, indent=4)
        print(f"[rank 0] Configuration saved to {config_filename}")
    # Load data (all ranks)
    if rank == 0:
        print(f"Loading data from {cfg.train_data_path} and {cfg.val_data_path}...")
    train_data_arr = np.memmap(cfg.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(cfg.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, cfg.block_length)
    val_dataset = PatientEventDataset(val_data_arr, cfg.block_length)
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=val_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
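    # Note: batch_size is per rank, so with DistributedSampler the effective
    # global batch is batch_size * world_size.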

    # Model, loss, optimizer
    model_cls = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }[cfg.model_name]
    model = model_cls(
        vocab_size=vocab_size,
        n_embd=cfg.n_embd,
        n_layer=cfg.n_layer,
        n_head=cfg.n_head,
        pdrop=cfg.pdrop,
        token_pdrop=cfg.token_pdrop,
    ).to(device)
    ddp_model = DDP(
        model,
        device_ids=[local_rank] if torch.cuda.is_available() else None,
        output_device=local_rank if torch.cuda.is_available() else None,
    )
    loss_fn = CombinedLoss(cfg.ignored_token_ids)
    optimizer = AdamW(ddp_model.parameters(), lr=cfg.lr_initial, weight_decay=cfg.weight_decay, betas=cfg.betas)
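    # Note: weight_decay is applied to every parameter here; some GPT recipes
    # exclude biases and LayerNorm weights via parameter groups, this script
    # does not.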

    best_val_loss = float('inf')
    patience_counter = 0
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if rank == 0:
        print("Starting DDP training...")
    for epoch in range(cfg.max_epoch):
        # Update sampler epoch so each epoch sees a different shuffle
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        # Set LR
        lr = cosine_lr(epoch, cfg)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # Train
        ddp_model.train()
        train_loss_ce_acc = torch.zeros(1, device=device)
        train_loss_surv_acc = torch.zeros(1, device=device)
        train_steps = 0
        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Train]", disable=(rank != 0))
        for event_seq, time_seq in pbar:
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)
            # Next-event prediction: inputs are positions 0..T-2, targets are
            # positions 1..T-1, and the survival target is the waiting time
            # between consecutive events.
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
            logits = ddp_model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            train_loss_ce_acc += loss_ce.detach()
            train_loss_surv_acc += loss_survival.detach()
            train_steps += 1
            if rank == 0:
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        # Aggregate train losses across ranks
        if train_steps == 0:
            train_steps = 1  # guard against an empty shard on this rank
        steps_tensor = torch.tensor([train_steps], device=device, dtype=torch.float64)
        dist.all_reduce(steps_tensor, op=dist.ReduceOp.SUM)
        train_loss_ce_mean = allreduce_avg(train_loss_ce_acc, world_size) / (steps_tensor.item() / world_size)
        train_loss_surv_mean = allreduce_avg(train_loss_surv_acc, world_size) / (steps_tensor.item() / world_size)
        if rank == 0:
            train_losses_ce.append(train_loss_ce_mean.item())
            train_losses_surv.append(train_loss_surv_mean.item())
            train_losses_total.append(train_loss_ce_mean.item() + train_loss_surv_mean.item())

        # Validation
        ddp_model.eval()
        val_loss_ce_acc = torch.zeros(1, device=device)
        val_loss_surv_acc = torch.zeros(1, device=device)
        val_steps = 0
        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Val]", disable=(rank != 0))
            for event_seq, time_seq in pbar_val:
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)
                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
                logits = ddp_model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
                val_loss_ce_acc += loss_ce.detach()
                val_loss_surv_acc += loss_survival.detach()
                val_steps += 1
        if val_steps == 0:
            val_steps = 1
        vsteps_tensor = torch.tensor([val_steps], device=device, dtype=torch.float64)
        dist.all_reduce(vsteps_tensor, op=dist.ReduceOp.SUM)
        val_loss_ce_mean = allreduce_avg(val_loss_ce_acc, world_size) / (vsteps_tensor.item() / world_size)
        val_loss_surv_mean = allreduce_avg(val_loss_surv_acc, world_size) / (vsteps_tensor.item() / world_size)
        total_val_loss = (val_loss_ce_mean + val_loss_surv_mean).item()

        if rank == 0:
            val_losses_ce.append(val_loss_ce_mean.item())
            val_losses_surv.append(val_loss_surv_mean.item())
            val_losses_total.append(total_val_loss)
            print(
                f"Epoch {epoch+1} Summary:\n"
                f"  Train Loss: {train_losses_total[-1]:.4f} (CE: {train_losses_ce[-1]:.4f}, Surv: {train_losses_surv[-1]:.4f})\n"
                f"  Val Loss:   {total_val_loss:.4f} (CE: {val_losses_ce[-1]:.4f}, Surv: {val_losses_surv[-1]:.4f})\n"
                f"  Learning Rate: {lr:.6f}"
            )
            # Early stopping decided on rank 0, then broadcast below
            improved = total_val_loss < best_val_loss
            if improved:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                torch.save(ddp_model.module.state_dict(), checkpoint_filename)
            else:
                if epoch >= cfg.warmup_epochs:
                    patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{cfg.early_stopping_patience}")
            stop_flag = torch.tensor([1 if patience_counter >= cfg.early_stopping_patience else 0], device=device, dtype=torch.int64)
        else:
            # Must match rank 0's dtype and shape: dist.broadcast copies into
            # this tensor, and a float32/int64 mismatch would corrupt the flag.
            stop_flag = torch.zeros(1, device=device, dtype=torch.int64)

        # Broadcast the stop flag so all ranks leave the loop in the same epoch
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item() > 0:
            if rank == 0:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # Save best model at the end (rank 0)
    if rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        state = torch.load(checkpoint_filename, map_location='cpu')
        ddp_model.module.load_state_dict(state)
        print(f"Saving final best model to {model_filename}")
        torch.save(ddp_model.module.state_dict(), model_filename)

        # Save losses to file
        losses_filename = f"losses_{model_suffix}.txt"
        with open(losses_filename, 'w') as f:
            f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")

        # Plot curves
        num_epochs = len(train_losses_total)
        epochs = range(1, num_epochs + 1)
        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.plot(epochs, train_losses_ce, label='Train CE')
        plt.plot(epochs, val_losses_ce, label='Val CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 2)
        plt.plot(epochs, train_losses_surv, label='Train Survival')
        plt.plot(epochs, val_losses_surv, label='Val Survival')
        plt.title('Survival Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 3)
        plt.plot(epochs, train_losses_total, label='Train Total')
        plt.plot(epochs, val_losses_total, label='Val Total')
        plt.title('Total Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves.png')
        print("\nLoss curves saved to loss_curves.png")

    cleanup_ddp()


if __name__ == '__main__':
    main()