from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt from model import DelphiFork, SapDelphi from dataset import HealthDataset, health_collate_fn from tqdm import tqdm from torch.nn.utils import clip_grad_norm_ from torch.utils.data import random_split from torch.utils.data import DataLoader from torch.optim import AdamW import torch.nn as nn import torch import json import os import time import argparse import math import sys from dataclasses import asdict, dataclass, field from typing import Literal, Sequence @dataclass class TrainConfig: # Model Parameters model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork' loss_type: Literal['exponential', 'weibull', 'piecewise_exponential'] = 'weibull' age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal' full_cov: bool = False n_embd: int = 120 n_head: int = 12 n_layer: int = 12 pdrop: float = 0.1 lambda_reg: float = 1e-4 bin_edges: Sequence[float] = field( default_factory=lambda: [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0] ) rank: int = 16 # SapDelphi specific pretrained_emd_path: str = "icd10_sapbert_embeddings.npy" # Data Parameters data_prefix: str = "ukb" train_ratio: float = 0.7 val_ratio: float = 0.15 random_seed: int = 42 # Training Parameters batch_size: int = 128 max_epochs: int = 200 warmup_epochs: int = 10 patience: int = 10 min_lr: float = 1e-5 max_lr: float = 5e-4 grad_clip: float = 1.0 weight_decay: float = 1e-2 device: str = 'cuda' if torch.cuda.is_available() else 'cpu' num_workers: int = 0 prefetch_factor: int = 2 persistent_workers: bool = False def parse_args() -> TrainConfig: parser = argparse.ArgumentParser(description="Train Delphi Model") parser.add_argument("--model_type", type=str, choices=[ 'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.") parser.add_argument("--loss_type", type=str, choices=[ 'exponential', 'weibull', 'piecewise_exponential'], default='weibull', help="Type of loss function to use.") parser.add_argument("--age_encoder", type=str, choices=[ 'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.") parser.add_argument("--n_embd", type=int, default=120, help="Embedding dimension.") parser.add_argument("--n_head", type=int, default=12, help="Number of attention heads.") parser.add_argument("--n_layer", type=int, default=12, help="Number of transformer layers.") parser.add_argument("--pdrop", type=float, default=0.1, help="Dropout probability.") parser.add_argument("--lambda_reg", type=float, default=1e-4, help="Regularization weight.") parser.add_argument("--rank", type=int, default=16, help="Rank for low-rank parameterization (if applicable).") parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy", help="Path to pretrained embeddings for SapDelphi.") parser.add_argument("--data_prefix", type=str, default="ukb", help="Prefix for dataset files.") parser.add_argument("--full_cov", action='store_true', help="Whether to use full covariates.") parser.add_argument("--train_ratio", type=float, default=0.7, help="Training data ratio.") parser.add_argument("--val_ratio", type=float, default=0.15, help="Validation data ratio.") parser.add_argument("--random_seed", type=int, default=42, help="Random seed for data splitting.") parser.add_argument("--batch_size", type=int, default=128, help="Batch size.") parser.add_argument("--max_epochs", type=int, default=200, help="Maximum number of epochs.") parser.add_argument("--warmup_epochs", type=int, default=10, help="Number of warmup epochs.") parser.add_argument("--patience", type=int, default=10, help="Early stopping patience.") parser.add_argument("--min_lr", type=float, default=1e-5, help="Minimum learning rate.") parser.add_argument("--max_lr", type=float, default=5e-4, help="Maximum learning rate.") parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value.") parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay for optimizer.") parser.add_argument("--num_workers", type=int, default=0, help="DataLoader workers (0 is safest on Windows).") parser.add_argument("--prefetch_factor", type=int, default=2, help="DataLoader prefetch factor (only used when num_workers>0).") parser.add_argument("--persistent_workers", action='store_true', help="Keep DataLoader workers alive between epochs (only if num_workers>0).") parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help="Device to use for training.") args = parser.parse_args() return TrainConfig(**vars(args)) def get_num_params(model: nn.Module) -> int: return sum(p.numel() for p in model.parameters() if p.requires_grad) class Trainer: def __init__( self, cfg: TrainConfig, ): self.cfg = cfg self.device = cfg.device self.global_step = 0 use_cuda = str(self.device).startswith( "cuda") and torch.cuda.is_available() if use_cuda: torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True try: torch.set_float32_matmul_precision("high") except Exception: pass if cfg.full_cov: cov_list = None else: cov_list = ["bmi", "smoking", "alcohol"] dataset = HealthDataset( data_prefix=cfg.data_prefix, covariate_list=cov_list, ) print("Dataset loaded.") n_total = len(dataset) print(f"Total samples in dataset: {n_total}") print(f"Number of diseases: {dataset.n_disease}") print(f"Number of continuous covariates: {dataset.n_cont}") print(f"Number of categorical covariates: {dataset.n_cate}") self.train_data, self.val_data, _ = random_split( dataset, [ int(n_total * cfg.train_ratio), int(n_total * cfg.val_ratio), n_total - int(n_total * cfg.train_ratio) - int(n_total * cfg.val_ratio), ], generator=torch.Generator().manual_seed(cfg.random_seed), ) pin_memory = use_cuda loader_kwargs = dict( collate_fn=health_collate_fn, pin_memory=pin_memory, ) if cfg.num_workers > 0: loader_kwargs["num_workers"] = cfg.num_workers loader_kwargs["prefetch_factor"] = cfg.prefetch_factor loader_kwargs["persistent_workers"] = cfg.persistent_workers self.train_loader = DataLoader( self.train_data, batch_size=cfg.batch_size, shuffle=True, **loader_kwargs, ) self.val_loader = DataLoader( self.val_data, batch_size=cfg.batch_size, shuffle=False, **loader_kwargs, ) if cfg.loss_type == "exponential": self.criterion = ExponentialNLLLoss( lambda_reg=cfg.lambda_reg, ).to(self.device) n_dim = 1 elif cfg.loss_type == "piecewise_exponential": self.criterion = PiecewiseExponentialLoss( bin_edges=cfg.bin_edges, lambda_reg=cfg.lambda_reg, ).to(self.device) n_dim = len(cfg.bin_edges) - 1 elif cfg.loss_type == "weibull": self.criterion = WeibullNLLLoss( lambda_reg=cfg.lambda_reg, ).to(self.device) n_dim = 2 else: raise ValueError(f"Unsupported loss type: {cfg.loss_type}") if cfg.model_type == "delphi_fork": self.model = DelphiFork( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, pdrop=cfg.pdrop, age_encoder_type=cfg.age_encoder, n_dim=n_dim, rank=cfg.rank, n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, ).to(self.device) elif cfg.model_type == "sap_delphi": self.model = SapDelphi( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=cfg.n_embd, n_head=cfg.n_head, n_layer=cfg.n_layer, pdrop=cfg.pdrop, age_encoder_type=cfg.age_encoder, n_dim=n_dim, rank=cfg.rank, n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, pretrained_weights_path=cfg.pretrained_emd_path, freeze_embeddings=True, ).to(self.device) else: raise ValueError(f"Unsupported model type: {cfg.model_type}") print(f"Model initialized: {cfg.model_type}") print(f"Number of trainable parameters: {get_num_params(self.model)}") self.optimizer = AdamW( self.model.parameters(), lr=cfg.max_lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.99), ) self.total_steps = (len(self.train_loader) * cfg.max_epochs) print(f"Total optimization steps: {self.total_steps}") while True: cov_suffix = "fullcov" if cfg.full_cov else "partcov" name = f"{cfg.model_type}_{cfg.loss_type}_{cfg.age_encoder}_{cov_suffix}" timestamp = time.strftime("%Y%m%d-%H%M%S") model_dir = os.path.join("runs", f"{name}_{timestamp}") if not os.path.exists(model_dir): self.out_dir = model_dir os.makedirs(model_dir) break time.sleep(1) print(f"Output directory: {self.out_dir}") self.best_path = os.path.join(self.out_dir, "best_model.pt") self.global_step = 0 self.save_config() def save_config(self): cfg_path = os.path.join(self.out_dir, "train_config.json") with open(cfg_path, 'w') as f: json.dump(asdict(self.cfg), f, indent=4) print(f"Configuration saved to {cfg_path}") def compute_lr(self, current_step: int) -> float: cfg = self.cfg if current_step < cfg.warmup_epochs * len(self.train_loader): lr = cfg.max_lr * (current_step / (cfg.warmup_epochs * len(self.train_loader))) else: denom = (cfg.max_epochs - cfg.warmup_epochs) * \ len(self.train_loader) progress = (current_step - cfg.warmup_epochs * len(self.train_loader)) / denom lr = cfg.min_lr + 0.5 * \ (cfg.max_lr - cfg.min_lr) * (1 + math.cos(math.pi * progress)) return lr def train(self) -> None: history = [] best_val_score = float('inf') patience_counter = 0 for epoch in range(1, self.cfg.max_epochs + 1): model_for_logging = self.model.module if hasattr( self.model, "module") else self.model delta_scale = None theta_proj = getattr(model_for_logging, "theta_proj", None) if theta_proj is not None and hasattr(theta_proj, "delta_scale"): try: delta_scale = float( theta_proj.delta_scale.detach().cpu().item()) except Exception: delta_scale = None self.model.train() total_train_pairs = 0 total_train_nll = 0.0 total_train_reg = 0.0 total_train_log_scale_sq = 0.0 total_train_log_shape_sq = 0.0 pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100) batch_count = 0 for batch in pbar: ( event_seq, time_seq, cont_feats, cate_feats, sexes, ) = batch event_seq = event_seq.to(self.device, non_blocking=True) time_seq = time_seq.to(self.device, non_blocking=True) cont_feats = cont_feats.to(self.device, non_blocking=True) cate_feats = cate_feats.to(self.device, non_blocking=True) sexes = sexes.to(self.device, non_blocking=True) res = get_valid_pairs_and_dt(event_seq, time_seq, 2) if res is None: continue dt, b_prev, t_prev, b_next, t_next = res num_pairs = dt.size(0) self.optimizer.zero_grad() lr = self.compute_lr(self.global_step) self.optimizer.param_groups[0]['lr'] = lr logits = self.model( event_seq, time_seq, sexes, cont_feats, cate_feats, b_prev=b_prev, t_prev=t_prev, ) if isinstance(self.criterion, WeibullNLLLoss): eps = float(self.criterion.eps) shapes = torch.nn.functional.softplus(logits[..., 0]) + eps scales = torch.nn.functional.softplus(logits[..., 1]) + eps log_scale_sq = (torch.log(scales + eps) ** 2).mean() log_shape_sq = (torch.log(shapes + eps) ** 2).mean() else: log_scale_sq = None log_shape_sq = None target_event = event_seq[b_next, t_next] - 2 nll_vec, reg = self.criterion( logits, target_event, dt, reduction="none", ) nll = nll_vec.mean() loss = nll + reg batch_count += 1 total_train_pairs += num_pairs total_train_nll += nll_vec.sum().item() total_train_reg += reg.item() * num_pairs if log_scale_sq is not None: total_train_log_scale_sq += log_scale_sq.item() * num_pairs if log_shape_sq is not None: total_train_log_shape_sq += log_shape_sq.item() * num_pairs avg_train_nll = total_train_nll / total_train_pairs avg_train_reg = total_train_reg / total_train_pairs pbar.set_postfix({ "lr": lr, "NLL": avg_train_nll, "Reg": avg_train_reg, }) loss.backward() if self.cfg.grad_clip > 0: clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip) self.optimizer.step() self.global_step += 1 if batch_count == 0: print("No valid batches in this epoch, skipping validation.") continue train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0 train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0 train_log_scale_sq = ( total_train_log_scale_sq / total_train_pairs if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) else None ) train_log_shape_sq = ( total_train_log_shape_sq / total_train_pairs if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) else None ) self.model.eval() total_val_pairs = 0 total_val_nll = 0.0 total_val_reg = 0.0 total_val_log_scale_sq = 0.0 total_val_log_shape_sq = 0.0 with torch.no_grad(): val_pbar = tqdm(self.val_loader, desc="Validation") for batch in val_pbar: ( event_seq, time_seq, cont_feats, cate_feats, sexes, ) = batch event_seq = event_seq.to(self.device, non_blocking=True) time_seq = time_seq.to(self.device, non_blocking=True) cont_feats = cont_feats.to(self.device, non_blocking=True) cate_feats = cate_feats.to(self.device, non_blocking=True) sexes = sexes.to(self.device, non_blocking=True) res = get_valid_pairs_and_dt(event_seq, time_seq, 2) if res is None: continue dt, b_prev, t_prev, b_next, t_next = res num_pairs = dt.size(0) logits = self.model( event_seq, time_seq, sexes, cont_feats, cate_feats, b_prev=b_prev, t_prev=t_prev ) if isinstance(self.criterion, WeibullNLLLoss): eps = float(self.criterion.eps) shapes = torch.nn.functional.softplus( logits[..., 0]) + eps scales = torch.nn.functional.softplus( logits[..., 1]) + eps log_scale_sq = (torch.log(scales + eps) ** 2).mean() log_shape_sq = (torch.log(shapes + eps) ** 2).mean() else: log_scale_sq = None log_shape_sq = None target_events = event_seq[b_next, t_next] - 2 nll, reg = self.criterion( logits, target_events, dt, reduction="none", ) batch_nll_sum = nll.sum().item() total_val_nll += batch_nll_sum total_val_reg += reg.item() * num_pairs if log_scale_sq is not None: total_val_log_scale_sq += log_scale_sq.item() * num_pairs if log_shape_sq is not None: total_val_log_shape_sq += log_shape_sq.item() * num_pairs total_val_pairs += num_pairs current_val_avg_nll = total_val_nll / \ total_val_pairs if total_val_pairs > 0 else 0.0 current_val_avg_reg = total_val_reg / \ total_val_pairs if total_val_pairs > 0 else 0.0 val_pbar.set_postfix({ "NLL": f"{current_val_avg_nll:.4f}", "Reg": f"{current_val_avg_reg:.4f}", }) val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0 val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0 val_log_scale_sq = ( total_val_log_scale_sq / total_val_pairs if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) else None ) val_log_shape_sq = ( total_val_log_shape_sq / total_val_pairs if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) else None ) history.append({ "epoch": epoch, "train_nll": train_nll, "train_reg": train_reg, "val_nll": val_nll, "val_reg": val_reg, "delta_scale": delta_scale, "train_log_scale_sq": train_log_scale_sq, "train_log_shape_sq": train_log_shape_sq, "val_log_scale_sq": val_log_scale_sq, "val_log_shape_sq": val_log_shape_sq, }) tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:") tqdm.write(f" Train NLL: {train_nll:.4f}") tqdm.write(f" Train Reg: {train_reg:.4f}") tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC") tqdm.write(f" Val Reg: {val_reg:.4f}") if delta_scale is not None: tqdm.write(f" Delta scale: {delta_scale:.6g}") if train_log_scale_sq is not None and train_log_shape_sq is not None: tqdm.write( f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}") tqdm.write( f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}") if val_log_scale_sq is not None and val_log_shape_sq is not None: tqdm.write( f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}") tqdm.write( f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}") with open(os.path.join(self.out_dir, "training_history.json"), "w") as f: json.dump(history, f, indent=4) # Check for improvement if val_nll < best_val_score: best_val_score = val_nll patience_counter = 0 tqdm.write(" āœ“ New best validation score. Saving checkpoint.") torch.save({ "epoch": epoch, "global_step": self.global_step, "model_state_dict": self.model.state_dict(), "criterion_state_dict": self.criterion.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), }, self.best_path) else: patience_counter += 1 if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience: tqdm.write( f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.") return tqdm.write( f" No improvement (patience: {patience_counter}/{self.cfg.patience})") tqdm.write("\nšŸŽ‰ Training complete!") if __name__ == "__main__": cfg = parse_args() trainer = Trainer(cfg) trainer.train()