diff --git a/train_dpp.py b/train_dpp.py
new file mode 100644
index 0000000..5f18a36
--- /dev/null
+++ b/train_dpp.py
@@ -0,0 +1,376 @@
+# train_dpp.py
+import os
+import math
+import numpy as np
+import tqdm
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Adam
+from torch.utils.data import DataLoader, DistributedSampler
+
+from models import TimeAwareGPT2, CombinedLoss
+from utils import PatientEventDataset
+
+# --- Configuration ---
+class TrainConfig:
+    # Data parameters
+    train_data_path = 'ukb_real_train.bin'
+    val_data_path = 'ukb_real_val.bin'
+    block_length = 24  # Sequence length
+
+    # Model parameters
+    n_embd = 256
+    n_layer = 8
+    n_head = 8
+    pdrop = 0.1
+    token_pdrop = 0.1
+
+    # Training parameters
+    max_epoch = 200
+    batch_size = 128
+    lr_initial = 6e-4
+    lr_final = 6e-5
+    warmup_epochs = 10
+    early_stopping_patience = 5
+
+    # Loss parameters
+    # 0 = padding, 1 = "no event"
+    ignored_token_ids = [0, 1]
+
+    # System parameters
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def ddp_is_active():
+    return dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
+
+def ddp_setup():
+    """Initialize process group if launched by torchrun."""
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
+        dist.init_process_group(backend=backend, init_method='env://')
+        return True
+    return False
+
+def ddp_cleanup():
+    if dist.is_initialized():
+        dist.barrier()
+        dist.destroy_process_group()
+
+def all_reduce_mean(value: torch.Tensor):
+    """Mean across processes (no-op if not DDP)."""
+    if ddp_is_active():
+        dist.all_reduce(value, op=dist.ReduceOp.SUM)
+        value /= dist.get_world_size()
+    return value
+
+def bcast_bool(stop: bool):
+    if not ddp_is_active():
+        return stop
+    t = torch.tensor(1 if stop else 0, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
+    dist.broadcast(t, src=0)
+    return bool(t.item())
+
+def main():
+    config = TrainConfig()
+
+    # --- DDP setup ---
+    ddp_enabled = ddp_setup()
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    global_rank = int(os.environ.get("RANK", 0)) if ddp_enabled else 0
+    world_size = int(os.environ.get("WORLD_SIZE", 1)) if ddp_enabled else 1
+
+    if torch.cuda.is_available():
+        torch.cuda.set_device(local_rank)
+        device = torch.device('cuda', local_rank)
+    else:
+        device = torch.device(config.device)
+
+    # Seed per-rank (different but deterministic)
+    torch.manual_seed(1337 + global_rank)
+    np.random.seed(1337 + global_rank)
+
+    is_main = (global_rank == 0)
+    if is_main:
+        print(f"DDP enabled: {ddp_enabled} | world_size={world_size} | device={device}")
+
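+    # Usage note (a sketch, not executed logic): ddp_setup() only creates a process
+    # group when torchrun has exported RANK/WORLD_SIZE, e.g. on a single node with
+    # 4 GPUs (an assumed example count):
+    #   torchrun --standalone --nproc_per_node=4 train_dpp.py
+    # Running `python train_dpp.py` directly skips DDP and trains on config.device.
+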
+    # --- 1. Data Loading ---
+    if is_main:
+        print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
+    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+
+    # Infer vocab_size from data (max label + 1)
+    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
+    if is_main:
+        print(f"Inferred vocabulary size: {vocab_size}")
+
+    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
+    val_dataset = PatientEventDataset(val_data_arr, config.block_length)
+
+    # DDP samplers (fall back to regular shuffle when not DDP)
+    if ddp_enabled:
+        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True, drop_last=False)
+        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=global_rank, shuffle=False, drop_last=False)
+        shuffle_flag = False
+    else:
+        train_sampler, val_sampler = None, None
+        shuffle_flag = True
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=shuffle_flag,
+        sampler=train_sampler,
+        num_workers=4,
+        pin_memory=True,
+        drop_last=False,
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        sampler=val_sampler,
+        num_workers=4,
+        pin_memory=True,
+        drop_last=False,
+    )
+
+    # --- 2. Model, Optimizer, Loss ---
+    if is_main:
+        print(f"Initializing model on {device}...")
+    model = TimeAwareGPT2(
+        vocab_size=vocab_size,
+        n_embd=config.n_embd,
+        n_layer=config.n_layer,
+        n_head=config.n_head,
+        pdrop=config.pdrop,
+        token_pdrop=config.token_pdrop
+    ).to(device)
+
+    if is_main:
+        # If your model has get_num_params
+        try:
+            print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")
+        except Exception:
+            total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+            print(f"Model initialized with {total_params/1e6:.2f}M trainable parameters.")
+
+    loss_fn = CombinedLoss(config.ignored_token_ids)
+    optimizer = Adam(model.parameters(), lr=config.lr_initial)
+
+    # Wrap with DDP
+    if ddp_enabled:
+        model = DDP(model, device_ids=[local_rank] if device.type == 'cuda' else None,
+                    output_device=local_rank if device.type == 'cuda' else None,
+                    find_unused_parameters=False)
+
+    # AMP
+    use_amp = (device.type == 'cuda')
+    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
+
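+    # NOTE (advisory, not enforced): under DDP each rank draws config.batch_size
+    # samples per optimizer step, so the effective global batch size is
+    # batch_size * world_size (e.g. 128 * 4 = 512 on 4 GPUs). lr_initial is kept at
+    # its single-process value here; scaling it with world_size is an optional tweak.
+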
+    # --- 3. Training Loop ---
+    best_val_loss = float('inf')
+    patience_counter = 0
+
+    # store losses only on main to plot
+    train_losses_ce, train_losses_surv, train_losses_total = [], [], []
+    val_losses_ce, val_losses_surv, val_losses_total = [], [], []
+
+    if is_main:
+        print("Starting training...")
+
+    for epoch in range(config.max_epoch):
+        # Ensure different shuffles per epoch under DDP
+        if ddp_enabled:
+            train_sampler.set_epoch(epoch)
+            val_sampler.set_epoch(epoch)
+
+        # --- LR scheduling (same as original) ---
+        if epoch < config.warmup_epochs:
+            lr = config.lr_initial
+        else:
+            progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs)
+            lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress))
+        for pg in optimizer.param_groups:
+            pg['lr'] = lr
+
+        # --- Training ---
+        if is_main:
+            pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]")
+        else:
+            pbar = train_loader  # silent on non-main ranks
+
+        model.train()
+        train_ce_sum = torch.tensor(0.0, device=device)
+        train_surv_sum = torch.tensor(0.0, device=device)
+        train_steps = 0
+
+        for batch in pbar:
+            event_seq, time_seq = batch
+            event_seq = event_seq.to(device, non_blocking=True)
+            time_seq = time_seq.to(device, non_blocking=True)
+
+            input_events = event_seq[:, :-1]
+            input_times = time_seq[:, :-1]
+            target_events = event_seq[:, 1:]
+            target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+            optimizer.zero_grad(set_to_none=True)
+            with torch.cuda.amp.autocast(enabled=use_amp):
+                logits = model(input_events, input_times)
+                loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+                loss = loss_ce + loss_survival
+
+            if use_amp:
+                scaler.scale(loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                optimizer.step()
+
+            train_ce_sum += loss_ce.detach()
+            train_surv_sum += loss_survival.detach()
+            train_steps += 1
+
+            if is_main and isinstance(pbar, tqdm.tqdm):
+                pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
+                                  'loss_surv': f'{loss_survival.item():.4f}',
+                                  'lr': f'{lr:.2e}'})
+
+        # Average train losses across ranks
+        train_ce_avg = train_ce_sum / max(1, train_steps)
+        train_surv_avg = train_surv_sum / max(1, train_steps)
+        train_ce_avg = all_reduce_mean(train_ce_avg)
+        train_surv_avg = all_reduce_mean(train_surv_avg)
+        train_total_avg = train_ce_avg + train_surv_avg
+
+        # --- Validation ---
+        if is_main:
+            pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]")
+        else:
+            pbar_val = val_loader
+
+        model.eval()
+        val_ce_sum = torch.tensor(0.0, device=device)
+        val_surv_sum = torch.tensor(0.0, device=device)
+        val_steps = 0
+        with torch.no_grad():
+            for batch in pbar_val:
+                event_seq, time_seq = batch
+                event_seq = event_seq.to(device, non_blocking=True)
+                time_seq = time_seq.to(device, non_blocking=True)
+
+                input_events = event_seq[:, :-1]
+                input_times = time_seq[:, :-1]
+                target_events = event_seq[:, 1:]
+                target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
+
+                with torch.cuda.amp.autocast(enabled=use_amp):
+                    logits = model(input_events, input_times)
+                    loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times)
+
+                val_ce_sum += loss_ce.detach()
+                val_surv_sum += loss_survival.detach()
+                val_steps += 1
+
+                if is_main and isinstance(pbar_val, tqdm.tqdm):
+                    pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}',
+                                          'loss_surv': f'{loss_survival.item():.4f}'})
+
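+        # NOTE: the all-reduce below averages per-rank mean losses. Because
+        # DistributedSampler(drop_last=False) pads every replica to the same number
+        # of samples, all ranks run the same number of batches, so this closely
+        # approximates the global mean (exact up to the padded duplicate samples).
+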
+        # Average val losses across ranks
+        val_ce_avg = val_ce_sum / max(1, val_steps)
+        val_surv_avg = val_surv_sum / max(1, val_steps)
+        val_ce_avg = all_reduce_mean(val_ce_avg)
+        val_surv_avg = all_reduce_mean(val_surv_avg)
+        val_total_avg = val_ce_avg + val_surv_avg
+
+        # --- Logging & Early Stopping (rank 0) ---
+        stop_now = False
+        if is_main:
+            print(f"Epoch {epoch+1} Summary:\n"
+                  f"  Train Loss: {float(train_total_avg):.4f} (CE: {float(train_ce_avg):.4f}, Surv: {float(train_surv_avg):.4f})\n"
+                  f"  Val Loss:   {float(val_total_avg):.4f} (CE: {float(val_ce_avg):.4f}, Surv: {float(val_surv_avg):.4f})\n"
+                  f"  Learning Rate: {lr:.6f}")
+
+            # Record for curves
+            train_losses_ce.append(float(train_ce_avg))
+            train_losses_surv.append(float(train_surv_avg))
+            train_losses_total.append(float(train_total_avg))
+            val_losses_ce.append(float(val_ce_avg))
+            val_losses_surv.append(float(val_surv_avg))
+            val_losses_total.append(float(val_total_avg))
+
+            if val_total_avg < best_val_loss:
+                best_val_loss = float(val_total_avg)
+                patience_counter = 0
+                print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint (rank0)...")
+                # unwrap DDP
+                to_save = model.module if isinstance(model, DDP) else model
+                torch.save(to_save.state_dict(), 'best_model_checkpoint.pt')
+            else:
+                if epoch >= config.warmup_epochs:
+                    patience_counter += 1
+                    print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}")
+
+                if patience_counter >= config.early_stopping_patience:
+                    print("\nEarly stopping triggered due to no improvement in validation loss.")
+                    stop_now = True
+
+        # Broadcast stop flag so all ranks exit together
+        stop_now = bcast_bool(stop_now)
+        if stop_now:
+            break
+
+    # --- Save Best Model at the End (rank 0 only) ---
+    if is_main:
+        if best_val_loss != float('inf'):
+            print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
+            # unwrap for loading & final save
+            to_load = model.module if isinstance(model, DDP) else model
+            to_load.load_state_dict(torch.load('best_model_checkpoint.pt', map_location='cpu'))
+            print("Saving final best model to best_model.pt")
+            torch.save(to_load.state_dict(), 'best_model.pt')
+        else:
+            print("\nTraining finished. No best model to save as validation loss never improved.")
+
+        # --- Plot and Save Loss Curves ---
+        num_epochs = len(train_losses_total)
+        if num_epochs > 0:
+            epochs = range(1, num_epochs + 1)
+            plt.figure(figsize=(18, 5))
+
+            plt.subplot(1, 3, 1)
+            plt.plot(epochs, train_losses_ce, label='Train CE')
+            plt.plot(epochs, val_losses_ce, label='Val CE')
+            plt.title('Cross-Entropy Loss')
+            plt.xlabel('Epochs'); plt.ylabel('Loss')
+            plt.legend(); plt.grid(True)
+
+            plt.subplot(1, 3, 2)
+            plt.plot(epochs, train_losses_surv, label='Train Survival')
+            plt.plot(epochs, val_losses_surv, label='Val Survival')
+            plt.title('Survival Loss')
+            plt.xlabel('Epochs'); plt.ylabel('Loss')
+            plt.legend(); plt.grid(True)
+
+            plt.subplot(1, 3, 3)
+            plt.plot(epochs, train_losses_total, label='Train Total')
+            plt.plot(epochs, val_losses_total, label='Val Total')
+            plt.title('Total Loss')
+            plt.xlabel('Epochs'); plt.ylabel('Loss')
+            plt.legend(); plt.grid(True)
+
+            plt.tight_layout()
+            plt.savefig('loss_curves.png')
+            print("\nLoss curves saved to loss_curves.png")
+
+    # Clean up DDP
+    ddp_cleanup()
+
+if __name__ == '__main__':
+    main()
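+
+# Usage sketch (commented out; assumes the same TrainConfig hyperparameters and the
+# vocab_size inferred at training time): reloading the exported weights for evaluation.
+#
+#   model = TimeAwareGPT2(vocab_size=vocab_size, n_embd=256, n_layer=8, n_head=8,
+#                         pdrop=0.1, token_pdrop=0.1)
+#   model.load_state_dict(torch.load('best_model.pt', map_location='cpu'))
+#   model.eval()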