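# Distributed (DDP) training script for the time-aware GPT-2 event models.
# Launch with torchrun, e.g. (the script filename and GPU count below are placeholders):
#
#   torchrun --standalone --nproc_per_node=4 <this_script>.py --model TimeAwareGPT2 \
#       --n_layer 12 --n_embd 120 --n_head 12
#
# Each process reads LOCAL_RANK / RANK / WORLD_SIZE from the environment set by torchrun.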
import os
import json
import math
import argparse
from typing import Tuple

import torch
import torch.distributed as dist
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

import numpy as np
import tqdm
import matplotlib.pyplot as plt

from models import TimeAwareGPT2, TimeAwareGPT2Learnable, CombinedLoss
from utils import PatientEventDataset


class TrainConfig:
    # Data parameters
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48

    # Model parameters
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.1
    token_pdrop = 0.1
    model_name = 'TimeAwareGPT2'

    # Training parameters
    max_epoch = 200
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 2e-1
    warmup_epochs = 10
    early_stopping_patience = 10
    betas = (0.9, 0.99)

    # Loss parameters (ignored tokens)
    ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]


def setup_ddp(backend: str | None = None):
    """Initialize torch.distributed from environment variables set by torchrun."""
    if backend is None:
        if torch.cuda.is_available() and os.name != 'nt':
            backend = 'nccl'
        else:
            backend = 'gloo'
    dist.init_process_group(backend=backend)

    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))

    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f'cuda:{local_rank}')
    else:
        device = torch.device('cpu')

    return rank, world_size, local_rank, device


def cleanup_ddp():
    if dist.is_initialized():
        dist.destroy_process_group()


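# Learning-rate schedule used below: the LR is held constant at cfg.lr_initial for the first
# cfg.warmup_epochs epochs, then follows a cosine decay down to cfg.lr_final. With the defaults
# (6e-4 -> 6e-5, warmup 10, 200 epochs) this gives roughly 6.0e-4 for epochs 0-9, ~3.3e-4 at
# epoch 105 (the cosine midpoint), and ~6.0e-5 by the final epoch.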
def cosine_lr(epoch: int, cfg: TrainConfig) -> float:
    if epoch < cfg.warmup_epochs:
        return cfg.lr_initial
    progress = (epoch - cfg.warmup_epochs) / max(1, (cfg.max_epoch - cfg.warmup_epochs))
    return cfg.lr_final + 0.5 * (cfg.lr_initial - cfg.lr_final) * (1 + math.cos(math.pi * progress))


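# Aggregation helper: each rank accumulates a *sum* of per-step losses; averaging those sums
# across ranks and then dividing by the average per-rank step count (done in main) yields the
# global per-step mean:  (sum_r loss_r / W) / (sum_r steps_r / W) = total_loss / total_steps.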
def allreduce_avg(value: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-reduce sum then divide by world_size."""
    value = value.clone().to(torch.float64)
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= world_size
    return value.to(torch.float32)


def main():
    parser = argparse.ArgumentParser(description='Train a Time-Aware GPT-2 model (DDP). Use torchrun to launch.')
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--n_embd', type=int, default=120)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--max_epoch', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr_initial', type=float, default=6e-4)
    parser.add_argument('--lr_final', type=float, default=6e-5)
    parser.add_argument('--weight_decay', type=float, default=2e-1)
    parser.add_argument('--warmup_epochs', type=int, default=10)
    parser.add_argument('--early_stopping_patience', type=int, default=10)
    parser.add_argument('--pdrop', type=float, default=0.1)
    parser.add_argument('--token_pdrop', type=float, default=0.1)
    parser.add_argument('--betas', type=float, nargs=2, default=[0.9, 0.99])
    parser.add_argument('--model', type=str, choices=['TimeAwareGPT2', 'TimeAwareGPT2Learnable'], default='TimeAwareGPT2')
    parser.add_argument('--backend', type=str, default=None, help='DDP backend (nccl/gloo). Default auto-selects.')

    args = parser.parse_args()

    rank, world_size, local_rank, device = setup_ddp(args.backend)

    # Build config
    cfg = TrainConfig()
    cfg.n_layer = args.n_layer
    cfg.n_embd = args.n_embd
    cfg.n_head = args.n_head
    cfg.max_epoch = args.max_epoch
    cfg.batch_size = args.batch_size
    cfg.lr_initial = args.lr_initial
    cfg.lr_final = args.lr_final
    cfg.weight_decay = args.weight_decay
    cfg.warmup_epochs = args.warmup_epochs
    cfg.early_stopping_patience = args.early_stopping_patience
    cfg.pdrop = args.pdrop
    cfg.token_pdrop = args.token_pdrop
    cfg.betas = tuple(args.betas)
    cfg.model_name = args.model

    # Filenames (shared across ranks)
    model_suffix = f"{cfg.model_name}_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}"
    model_filename = f"best_model_{model_suffix}.pt"
    checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt"
    config_filename = f"config_n_embd_{cfg.n_embd}_n_layer_{cfg.n_layer}_n_head_{cfg.n_head}.json"

    # Save config only on rank 0
    if rank == 0:
        # vars(cfg) only contains the attributes overwritten above; merge in the class-level
        # defaults (data paths, block_length, ignored_token_ids, ...) so the saved config is complete.
        cfg_dict = {k: v for k, v in vars(TrainConfig).items() if not k.startswith('__')}
        cfg_dict.update(vars(cfg))
        with open(config_filename, 'w') as f:
            json.dump(cfg_dict, f, indent=4)
        print(f"[rank 0] Configuration saved to {config_filename}")

    # Load data (all ranks)
    if rank == 0:
        print(f"Loading data from {cfg.train_data_path} and {cfg.val_data_path}...")
    train_data_arr = np.memmap(cfg.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(cfg.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    if rank == 0:
        print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, cfg.block_length)
    val_dataset = PatientEventDataset(val_data_arr, cfg.block_length)

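    # DistributedSampler gives each rank a distinct shard of ~len(dataset)/world_size samples per
    # epoch; with drop_last=False the last shard is padded by repeating samples so every rank runs
    # the same number of steps.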
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
    val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)

    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=4, pin_memory=torch.cuda.is_available())
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, sampler=val_sampler, num_workers=4, pin_memory=torch.cuda.is_available())

    # Model, loss, optimizer
    model_cls = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }[cfg.model_name]

    model = model_cls(
        vocab_size=vocab_size,
        n_embd=cfg.n_embd,
        n_layer=cfg.n_layer,
        n_head=cfg.n_head,
        pdrop=cfg.pdrop,
        token_pdrop=cfg.token_pdrop,
    ).to(device)

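    # DDP wrapper: gradients are averaged across ranks inside backward(), so each rank can run a
    # plain optimizer.step() on its own replica and the replicas stay in sync.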
    ddp_model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None, output_device=local_rank if torch.cuda.is_available() else None)

    loss_fn = CombinedLoss(cfg.ignored_token_ids)
    optimizer = AdamW(ddp_model.parameters(), lr=cfg.lr_initial, weight_decay=cfg.weight_decay, betas=cfg.betas)

    best_val_loss = float('inf')
    patience_counter = 0

    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
    val_losses_ce, val_losses_surv, val_losses_total = [], [], []

    if rank == 0:
        print("Starting DDP training...")

    for epoch in range(cfg.max_epoch):
        # Update sampler epoch for shuffling
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        # Set LR
        lr = cosine_lr(epoch, cfg)
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # Train
        ddp_model.train()
        train_loss_ce_acc = torch.zeros(1, device=device)
        train_loss_surv_acc = torch.zeros(1, device=device)
        train_steps = 0

        pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Train]", disable=(rank != 0))
        for event_seq, time_seq in pbar:
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)

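            # Next-event prediction setup: the model conditions on all but the last event/time,
            # target_events is the sequence shifted by one, and target_wait_times is the time gap
            # to that next event (used by the survival part of the loss).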
            input_events = event_seq[:, :-1]
            input_times = time_seq[:, :-1]
            target_events = event_seq[:, 1:]
            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

            logits = ddp_model(input_events, input_times)
            loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
            loss = loss_ce + loss_survival

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            train_loss_ce_acc += loss_ce.detach()
            train_loss_surv_acc += loss_survival.detach()
            train_steps += 1

            if rank == 0:
                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'})

        # Aggregate train losses across ranks
        if train_steps == 0:
            train_steps = 1
        steps_tensor = torch.tensor([train_steps], device=device, dtype=torch.float64)
        dist.all_reduce(steps_tensor, op=dist.ReduceOp.SUM)
        train_loss_ce_mean = allreduce_avg(train_loss_ce_acc, world_size) / (steps_tensor.item() / world_size)
        train_loss_surv_mean = allreduce_avg(train_loss_surv_acc, world_size) / (steps_tensor.item() / world_size)

        if rank == 0:
            train_losses_ce.append(train_loss_ce_mean.item())
            train_losses_surv.append(train_loss_surv_mean.item())
            train_losses_total.append(train_loss_ce_mean.item() + train_loss_surv_mean.item())

        # Validation
        ddp_model.eval()
        val_loss_ce_acc = torch.zeros(1, device=device)
        val_loss_surv_acc = torch.zeros(1, device=device)
        val_steps = 0

        with torch.no_grad():
            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{cfg.max_epoch} [Val]", disable=(rank != 0))
            for event_seq, time_seq in pbar_val:
                event_seq = event_seq.to(device, non_blocking=True)
                time_seq = time_seq.to(device, non_blocking=True)

                input_events = event_seq[:, :-1]
                input_times = time_seq[:, :-1]
                target_events = event_seq[:, 1:]
                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()

                logits = ddp_model(input_events, input_times)
                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)

                val_loss_ce_acc += loss_ce.detach()
                val_loss_surv_acc += loss_survival.detach()
                val_steps += 1

        if val_steps == 0:
            val_steps = 1
        vsteps_tensor = torch.tensor([val_steps], device=device, dtype=torch.float64)
        dist.all_reduce(vsteps_tensor, op=dist.ReduceOp.SUM)
        val_loss_ce_mean = allreduce_avg(val_loss_ce_acc, world_size) / (vsteps_tensor.item() / world_size)
        val_loss_surv_mean = allreduce_avg(val_loss_surv_acc, world_size) / (vsteps_tensor.item() / world_size)
        total_val_loss = (val_loss_ce_mean + val_loss_surv_mean).item()

        if rank == 0:
            val_losses_ce.append(val_loss_ce_mean.item())
            val_losses_surv.append(val_loss_surv_mean.item())
            val_losses_total.append(total_val_loss)

            print(
                f"Epoch {epoch+1} Summary:\n"
                f" Train Loss: {train_losses_total[-1]:.4f} (CE: {train_losses_ce[-1]:.4f}, Surv: {train_losses_surv[-1]:.4f})\n"
                f" Val Loss: {total_val_loss:.4f} (CE: {val_losses_ce[-1]:.4f}, Surv: {val_losses_surv[-1]:.4f})\n"
                f" Learning Rate: {lr:.6f}"
            )

            # Early stopping on rank 0; broadcast decision
            improved = total_val_loss < best_val_loss
            if improved:
                best_val_loss = total_val_loss
                patience_counter = 0
                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
                torch.save(ddp_model.module.state_dict(), checkpoint_filename)
            else:
                if epoch >= cfg.warmup_epochs:
                    patience_counter += 1
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{cfg.early_stopping_patience}")

            # Both branches must build a tensor of the same dtype and shape, since every rank
            # participates in the broadcast below.
            stop_flag = torch.tensor([1.0 if patience_counter >= cfg.early_stopping_patience else 0.0], device=device)
        else:
            stop_flag = torch.zeros(1, device=device)

        # Broadcast the stop decision from rank 0 to all ranks (best_val_loss itself stays on rank 0)
        dist.broadcast(stop_flag, src=0)
        if stop_flag.item() > 0:
            if rank == 0:
                print("\nEarly stopping triggered due to no improvement in validation loss.")
            break

    # Save best model at the end (rank 0)
    if rank == 0 and best_val_loss != float('inf'):
        print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
        state = torch.load(checkpoint_filename, map_location='cpu')
        ddp_model.module.load_state_dict(state)
        print(f"Saving final best model to {model_filename}")
        torch.save(ddp_model.module.state_dict(), model_filename)

        # Save losses to file
        losses_filename = f"losses_{model_suffix}.txt"
        with open(losses_filename, 'w') as f:
            f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
            for i in range(len(train_losses_total)):
                f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
        print(f"\nLosses saved to {losses_filename}")
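        # The losses file is a plain CSV and can be reloaded later for further analysis, e.g.
        # (sketch, assuming pandas is available): pd.read_csv(losses_filename).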

        # Plot curves
        num_epochs = len(train_losses_total)
        epochs = range(1, num_epochs + 1)
        plt.figure(figsize=(18, 5))

        plt.subplot(1, 3, 1)
        plt.plot(epochs, train_losses_ce, label='Train CE')
        plt.plot(epochs, val_losses_ce, label='Val CE')
        plt.title('Cross-Entropy Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 2)
        plt.plot(epochs, train_losses_surv, label='Train Survival')
        plt.plot(epochs, val_losses_surv, label='Val Survival')
        plt.title('Survival Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.subplot(1, 3, 3)
        plt.plot(epochs, train_losses_total, label='Train Total')
        plt.plot(epochs, val_losses_total, label='Val Total')
        plt.title('Total Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend(); plt.grid(True)

        plt.tight_layout()
        plt.savefig('loss_curves.png')
        print("\nLoss curves saved to loss_curves.png")

    cleanup_ddp()


if __name__ == '__main__':
    main()