diff --git a/train.py b/train.py
new file mode 100644
index 0000000..af147ea
--- /dev/null
+++ b/train.py
@@ -0,0 +1,481 @@
+import json
+import os
+import time
+import argparse
+import math
+from dataclasses import asdict, dataclass
+from typing import Literal, Sequence
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from torch.utils.data import random_split
+from torch.nn.utils import clip_grad_norm_
+from tqdm import tqdm
+
+from dataset import HealthDataset, health_collate_fn
+from model import DelphiFork, SapDelphi
+from losses import ExponentialNLLLoss, WeibullNLLLoss, get_valid_pairs_and_dt
+
+
+@dataclass
+class TrainConfig:
+    # Model Parameters
+    model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
+    loss_type: Literal['exponential', 'weibull'] = 'weibull'
+    age_encoder: Literal['sinusoidal', 'learned'] = 'learned'
+    full_cov: bool = False
+    n_embd: int = 120
+    n_head: int = 12
+    n_layer: int = 12
+    pdrop: float = 0.1
+    lambda_reg: float = 1e-4
+    # SapDelphi specific
+    pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
+    # Data Parameters
+    data_prefix: str = "ukb"
+    train_ratio: float = 0.7
+    val_ratio: float = 0.15
+    random_seed: int = 42
+    # Training Parameters
+    batch_size: int = 128
+    max_epochs: int = 200
+    warmup_epochs: int = 10
+    patience: int = 10
+    min_lr: float = 1e-5
+    max_lr: float = 5e-4
+    grad_clip: float = 1.0
+    weight_decay: float = 1e-2
+    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    # EMA parameters
+    ema_decay: float = 0.999
+
+
+def parse_args() -> TrainConfig:
+    parser = argparse.ArgumentParser(description="Train Delphi Model")
+    parser.add_argument("--model_type", type=str, choices=[
+        'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
+    parser.add_argument("--loss_type", type=str, choices=[
+        'exponential', 'weibull'], default='weibull', help="Type of loss function to use.")
+    parser.add_argument("--age_encoder", type=str, choices=[
+        'sinusoidal', 'learned'], default='learned', help="Type of age encoder to use.")
+    parser.add_argument("--n_embd", type=int, default=120,
+                        help="Embedding dimension.")
+    parser.add_argument("--n_head", type=int, default=12,
+                        help="Number of attention heads.")
+    parser.add_argument("--n_layer", type=int, default=12,
+                        help="Number of transformer layers.")
+    parser.add_argument("--pdrop", type=float, default=0.1,
+                        help="Dropout probability.")
+    parser.add_argument("--lambda_reg", type=float,
+                        default=1e-4, help="Regularization weight.")
+    parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
+                        help="Path to pretrained embeddings for SapDelphi.")
+    parser.add_argument("--data_prefix", type=str,
+                        default="ukb", help="Prefix for dataset files.")
+    parser.add_argument("--full_cov", action='store_true',
+                        help="Whether to use full covariates.")
+    parser.add_argument("--train_ratio", type=float,
+                        default=0.7, help="Training data ratio.")
+    parser.add_argument("--val_ratio", type=float,
+                        default=0.15, help="Validation data ratio.")
+    parser.add_argument("--random_seed", type=int, default=42,
+                        help="Random seed for data splitting.")
+    parser.add_argument("--batch_size", type=int,
+                        default=128, help="Batch size.")
+    parser.add_argument("--max_epochs", type=int, default=200,
+                        help="Maximum number of epochs.")
+    parser.add_argument("--warmup_epochs", type=int,
+                        default=10, help="Number of warmup epochs.")
parser.add_argument("--patience", type=int, default=10, + help="Early stopping patience.") + parser.add_argument("--min_lr", type=float, default=1e-5, + help="Minimum learning rate.") + parser.add_argument("--max_lr", type=float, default=5e-4, + help="Maximum learning rate.") + parser.add_argument("--grad_clip", type=float, + default=1.0, help="Gradient clipping value.") + parser.add_argument("--weight_decay", type=float, + default=1e-2, help="Weight decay for optimizer.") + parser.add_argument("--ema_decay", type=float, + default=0.999, help="EMA decay rate.") + parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu', + help="Device to use for training.") + args = parser.parse_args() + return TrainConfig(**vars(args)) + + +def get_num_params(model: nn.Module) -> int: + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +class Trainer: + def __init__( + self, + cfg: TrainConfig, + ): + + self.cfg = cfg + self.device = cfg.device + self.global_step = 0 + + if cfg.full_cov: + cov_list = None + else: + cov_list = ["bmi", "smoking", "alcohol"] + dataset = HealthDataset( + data_prefix=cfg.data_prefix, + covariate_list=cov_list, + ) + print("Dataset loaded.") + n_total = len(dataset) + print(f"Total samples in dataset: {n_total}") + print(f"Number of diseases: {dataset.n_disease}") + print(f"Number of continuous covariates: {dataset.n_cont}") + print(f"Number of categorical covariates: {dataset.n_cate}") + self.train_data, self.val_data, _ = random_split( + dataset, + [ + int(n_total * cfg.train_ratio), + int(n_total * cfg.val_ratio), + n_total - int(n_total * cfg.train_ratio) - + int(n_total * cfg.val_ratio), + ], + generator=torch.Generator().manual_seed(cfg.random_seed), + ) + self.train_loader = DataLoader( + self.train_data, + batch_size=cfg.batch_size, + shuffle=True, + collate_fn=health_collate_fn, + ) + self.val_loader = DataLoader( + self.val_data, + batch_size=cfg.batch_size, + shuffle=False, + collate_fn=health_collate_fn, + ) + + if cfg.loss_type == "exponential": + self.criterion = ExponentialNLLLoss( + lambda_reg=cfg.lambda_reg, + ).to(self.device) + n_dim = 1 + elif cfg.loss_type == "weibull": + self.criterion = WeibullNLLLoss( + lambda_reg=cfg.lambda_reg, + ).to(self.device) + n_dim = 2 + else: + raise ValueError(f"Unsupported loss type: {cfg.loss_type}") + + if cfg.model_type == "delphi_fork": + self.model = DelphiFork( + n_disease=dataset.n_disease, + n_embd=cfg.n_embd, + n_head=cfg.n_head, + n_layer=cfg.n_layer, + pdrop=cfg.pdrop, + age_encoder=cfg.age_encoder, + n_dim=n_dim, + n_cont=dataset.n_cont, + n_cate=dataset.n_cate, + ).to(self.device) + elif cfg.model_type == "sap_delphi": + self.model = SapDelphi( + n_disease=dataset.n_disease, + n_embd=cfg.n_embd, + n_head=cfg.n_head, + n_layer=cfg.n_layer, + pdrop=cfg.pdrop, + age_encoder=cfg.age_encoder, + n_dim=n_dim, + n_cont=dataset.n_cont, + n_cate=dataset.n_cate, + pretrained_emd_path=cfg.pretrained_emd_path, + freeze_pretrained_emd=True, + ).to(self.device) + else: + raise ValueError(f"Unsupported model type: {cfg.model_type}") + print(f"Model initialized: {cfg.model_type}") + print(f"Number of trainable parameters: {get_num_params(self.model)}") + + # Initialize EMA model + self.ema_model = None + if cfg.ema_decay < 1.0: + if cfg.model_type == "delphi_fork": + self.ema_model = DelphiFork( + n_disease=dataset.n_disease, + n_embd=cfg.n_embd, + n_head=cfg.n_head, + n_layer=cfg.n_layer, + pdrop=cfg.pdrop, + age_encoder=cfg.age_encoder, + n_dim=n_dim, + 
+                    n_cont=dataset.n_cont,
+                    n_cate=dataset.n_cate,
+                ).to(self.device)
+            elif cfg.model_type == "sap_delphi":
+                self.ema_model = SapDelphi(
+                    n_disease=dataset.n_disease,
+                    n_embd=cfg.n_embd,
+                    n_head=cfg.n_head,
+                    n_layer=cfg.n_layer,
+                    pdrop=cfg.pdrop,
+                    age_encoder=cfg.age_encoder,
+                    n_dim=n_dim,
+                    n_cont=dataset.n_cont,
+                    n_cate=dataset.n_cate,
+                    pretrained_emd_path=cfg.pretrained_emd_path,
+                    freeze_pretrained_emd=True,
+                ).to(self.device)
+            else:
+                raise ValueError(f"Unsupported model type: {cfg.model_type}")
+            self.ema_model.load_state_dict(self.model.state_dict())
+            for param in self.ema_model.parameters():
+                param.requires_grad = False
+            print("EMA model initialized.")
+
+        self.optimizer = AdamW(
+            self.model.parameters(),
+            lr=cfg.max_lr,
+            weight_decay=cfg.weight_decay,
+            betas=(0.9, 0.99),
+        )
+        self.total_steps = len(self.train_loader) * cfg.max_epochs
+        print(f"Total optimization steps: {self.total_steps}")
+
+        # Create a unique, timestamped output directory for this run.
+        while True:
+            cov_suffix = "fullcov" if cfg.full_cov else "partcov"
+            name = f"{cfg.model_type}_{cfg.loss_type}_{cfg.age_encoder}_{cov_suffix}"
+            timestamp = time.strftime("%Y%m%d-%H%M%S")
+            model_dir = os.path.join("runs", f"{name}_{timestamp}")
+            if not os.path.exists(model_dir):
+                self.out_dir = model_dir
+                os.makedirs(model_dir)
+                break
+            time.sleep(1)
+        print(f"Output directory: {self.out_dir}")
+        self.best_path = os.path.join(self.out_dir, "best_model.pt")
+        self.last_path = os.path.join(self.out_dir, "last_model.pt")
+        self.save_config()
+
+    def save_config(self):
+        cfg_path = os.path.join(self.out_dir, "train_config.json")
+        with open(cfg_path, 'w') as f:
+            json.dump(asdict(self.cfg), f, indent=4)
+        print(f"Configuration saved to {cfg_path}")
+
+    def update_ema(self):
+        if self.ema_model is None:
+            return
+        decay = self.cfg.ema_decay
+        with torch.no_grad():
+            model_params = dict(self.model.named_parameters())
+            ema_params = dict(self.ema_model.named_parameters())
+            for name in model_params.keys():
+                ema_params[name].data.mul_(decay).add_(
+                    model_params[name].data, alpha=1 - decay)
+
+    def compute_lr(self, current_step: int) -> float:
+        # Linear warmup to max_lr, then cosine decay to min_lr.
+        cfg = self.cfg
+        if current_step < cfg.warmup_epochs * len(self.train_loader):
+            lr = cfg.max_lr * (current_step /
+                               (cfg.warmup_epochs * len(self.train_loader)))
+        else:
+            denom = (cfg.max_epochs - cfg.warmup_epochs) * \
+                len(self.train_loader)
+            progress = (current_step - cfg.warmup_epochs *
+                        len(self.train_loader)) / denom
+            lr = cfg.min_lr + 0.5 * \
+                (cfg.max_lr - cfg.min_lr) * (1 + math.cos(math.pi * progress))
+        return lr
+
+    def train(self) -> None:
+        history = []
+        best_val_loss = float('inf')
+        patience_counter = 0
+        for epoch in range(1, self.cfg.max_epochs + 1):
+            self.model.train()
+            running_nll = 0.0
+            running_reg = 0.0
+            batch_count = 0  # batches that contributed to the running averages
+            pbar = tqdm(self.train_loader,
+                        desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
+            for batch in pbar:
+                (
+                    event_seq,
+                    time_seq,
+                    cont_feats,
+                    cate_feats,
+                    sexes,
+                ) = batch
+                event_seq = event_seq.to(self.device)
+                time_seq = time_seq.to(self.device)
+                cont_feats = cont_feats.to(self.device)
+                cate_feats = cate_feats.to(self.device)
+                sexes = sexes.to(self.device)
+                res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
+                if res is None:
+                    continue
+                dt, b_prev, t_prev, b_next, t_next = res
+                self.optimizer.zero_grad()
+                lr = self.compute_lr(self.global_step)
+                for param_group in self.optimizer.param_groups:
+                    param_group['lr'] = lr
+                logits = self.model(
+                    event_seq,
+                    time_seq,
+                    sexes,
+                    cont_feats,
+                    cate_feats,
+                    b_prev=b_prev,
+                    t_prev=t_prev,
+                )
+                target_event = event_seq[b_next, t_next] - 2
+                nll_vec, reg = self.criterion(
+                    logits,
+                    target_event,
+                    dt,
+                    reduction="none",
+                )
+                finite_mask = torch.isfinite(nll_vec)
+                if not finite_mask.any():
+                    continue
+                nll_vec = nll_vec[finite_mask]
+                nll = nll_vec.mean()
+
+                loss = nll + reg
+                batch_count += 1
+                running_nll += nll.item()
+                running_reg += reg.item()
+                pbar.set_postfix({
+                    "lr": lr,
+                    "NLL": running_nll / batch_count,
+                    "Reg": running_reg / batch_count,
+                })
+                loss.backward()
+                if self.cfg.grad_clip > 0:
+                    clip_grad_norm_(self.model.parameters(),
+                                    self.cfg.grad_clip)
+                self.optimizer.step()
+                self.update_ema()
+                self.global_step += 1
+
+            if batch_count == 0:
+                print("No valid batches in this epoch, skipping validation.")
+                continue
+
+            train_nll = running_nll / batch_count
+            train_reg = running_reg / batch_count
+
+            # Validate with the EMA weights when available, otherwise the raw model.
+            eval_model = self.ema_model if self.ema_model is not None else self.model
+            eval_model.eval()
+            total_val_pairs = 0
+            total_val_nll = 0.0
+            total_val_reg = 0.0
+            with torch.no_grad():
+                val_pbar = tqdm(self.val_loader, desc="Validation")
+                for batch in val_pbar:
+                    (
+                        event_seq,
+                        time_seq,
+                        cont_feats,
+                        cate_feats,
+                        sexes,
+                    ) = batch
+                    event_seq = event_seq.to(self.device)
+                    time_seq = time_seq.to(self.device)
+                    cont_feats = cont_feats.to(self.device)
+                    cate_feats = cate_feats.to(self.device)
+                    sexes = sexes.to(self.device)
+                    res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
+                    if res is None:
+                        continue
+                    dt, b_prev, t_prev, b_next, t_next = res
+                    num_pairs = dt.size(0)
+                    logits = eval_model(
+                        event_seq,
+                        time_seq,
+                        sexes,
+                        cont_feats,
+                        cate_feats,
+                        b_prev=b_prev,
+                        t_prev=t_prev,
+                    )
+                    target_events = event_seq[b_next, t_next] - 2
+                    nll, reg = self.criterion(
+                        logits,
+                        target_events,
+                        dt,
+                        reduction="none",
+                    )
+                    batch_nll_sum = nll.sum().item()
+                    total_val_nll += batch_nll_sum
+                    total_val_reg += reg.item() * num_pairs
+                    total_val_pairs += num_pairs
+
+                    current_val_avg_nll = total_val_nll / \
+                        total_val_pairs if total_val_pairs > 0 else 0.0
+                    current_val_avg_reg = total_val_reg / \
+                        total_val_pairs if total_val_pairs > 0 else 0.0
+
+                    val_pbar.set_postfix({
+                        "NLL": f"{current_val_avg_nll:.4f}",
+                        "Reg": f"{current_val_avg_reg:.4f}",
+                    })
+
+            val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
+            val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
+
+            history.append({
+                "epoch": epoch,
+                "train_nll": train_nll,
+                "train_reg": train_reg,
+                "val_nll": val_nll,
+                "val_reg": val_reg,
+            })
+
+            tqdm.write(f"\nEpoch {epoch}/{self.cfg.max_epochs} Stats:")
+            tqdm.write(f"  Train NLL: {train_nll:.4f}")
+            tqdm.write(f"  Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
+
+            with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
+                json.dump(history, f, indent=4)
+
+            # Check for improvement
+            if val_nll < best_val_loss:
+                best_val_loss = val_nll
+                patience_counter = 0
+                tqdm.write("  ✓ New best validation score. Saving checkpoint.")
+
+                torch.save({
+                    "epoch": epoch,
+                    "global_step": self.global_step,
+                    "model_state_dict": eval_model.state_dict(),
+                    "criterion_state_dict": self.criterion.state_dict(),
+                    "optimizer_state_dict": self.optimizer.state_dict(),
+                }, self.best_path)
+            else:
+                patience_counter += 1
+                if patience_counter >= self.cfg.patience:
+                    tqdm.write(
+                        f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
Early stopping.") + return + tqdm.write( + f" No improvement (patience: {patient_counter}/{self.cfg.patient_epochs})") + + torch.save({ + "epoch": epoch, + "global_step": self.global_step, + "model_state_dict": self.ema_model.state_dict(), + "criterion_state_dict": self.criterion.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + }, self.last_path) + + tqdm.write("\nšŸŽ‰ Training complete!") + + +if __name__ == "__main__": + cfg = parse_args() + trainer = Trainer(cfg) + trainer.train()