from __future__ import annotations import argparse import json import math import os from typing import List, Sequence, Tuple import torch from torch.utils.data import DataLoader, random_split from dataset import HealthDataset, health_collate_fn from evaluation_age_time_dependent import EvalAgeConfig, evaluate_time_dependent_age_bins from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss from model import DelphiFork, SapDelphi, SimpleHead def _parse_floats(items: Sequence[str]) -> List[float]: out: List[float] = [] for x in items: x = x.strip() if not x: continue out.append(float(x)) return out def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]: vals = _parse_floats(edges) if len(vals) < 2: raise ValueError("--age_bin_edges must have at least 2 values") for i in range(1, len(vals)): if not (vals[i] > vals[i - 1]): raise ValueError("--age_bin_edges must be strictly increasing") return vals def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]: return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)] def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float): if loss_type == "exponential": criterion = ExponentialNLLLoss(lambda_reg=lambda_reg) out_dims = [n_disease] return criterion, out_dims if loss_type == "discrete_time_cif": criterion = DiscreteTimeCIFNLLLoss( bin_edges=bin_edges, lambda_reg=lambda_reg) out_dims = [n_disease + 1, len(bin_edges)] return criterion, out_dims if loss_type == "pwe_cif": pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))] if len(pwe_edges) < 2: raise ValueError( "pwe_cif requires at least 2 finite bin edges (including 0)") if float(pwe_edges[0]) != 0.0: raise ValueError("pwe_cif requires bin_edges[0]==0.0") criterion = PiecewiseExponentialCIFNLLLoss( bin_edges=pwe_edges, lambda_reg=lambda_reg) n_bins = len(pwe_edges) - 1 out_dims = [n_disease, n_bins] return criterion, out_dims raise ValueError(f"Unsupported loss_type: {loss_type}") def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict): if model_type == "delphi_fork": return DelphiFork( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=int(cfg["n_embd"]), n_head=int(cfg["n_head"]), n_layer=int(cfg["n_layer"]), pdrop=float(cfg.get("pdrop", 0.0)), age_encoder_type=str(cfg["age_encoder"]), n_cont=int(dataset.n_cont), n_cate=int(dataset.n_cate), cate_dims=list(dataset.cate_dims), ) if model_type == "sap_delphi": return SapDelphi( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=int(cfg["n_embd"]), n_head=int(cfg["n_head"]), n_layer=int(cfg["n_layer"]), pdrop=float(cfg.get("pdrop", 0.0)), age_encoder_type=str(cfg["age_encoder"]), n_cont=int(dataset.n_cont), n_cate=int(dataset.n_cate), cate_dims=list(dataset.cate_dims), pretrained_weights_path=str( cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")), freeze_embeddings=bool(cfg.get("freeze_embeddings", True)), ) raise ValueError(f"Unsupported model_type: {model_type}") def main() -> None: parser = argparse.ArgumentParser( description="Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and aggregated CSVs; aggregated includes agg_type={macro,weighted})") parser.add_argument( "--run_dir", type=str, required=True, help="Training run directory (contains best_model.pt and train_config.json)", ) parser.add_argument("--data_prefix", type=str, default=None) parser.add_argument("--split", type=str, choices=["train", "val", "test", "all"], default="val") parser.add_argument("--horizons", type=str, nargs="+", default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"]) parser.add_argument( "--age_bin_edges", type=str, nargs="+", default=["40", "45", "50", "55", "60", "65", "70", "75", "80"], help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). Bins are [edge[i], edge[i+1]).", ) parser.add_argument( "--topk_percent", type=float, nargs="+", default=[1, 5, 10, 20, 50], help="One or more K%% values for recall/precision@K%%", ) parser.add_argument("--n_mc", type=int, default=5) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_workers", type=int, default=0) parser.add_argument("--out_prefix", type=str, default=None, help="Output prefix for CSVs") args = parser.parse_args() ckpt_path = os.path.join(args.run_dir, "best_model.pt") cfg_path = os.path.join(args.run_dir, "train_config.json") if not os.path.exists(ckpt_path): raise SystemExit(f"Missing checkpoint: {ckpt_path}") if not os.path.exists(cfg_path): raise SystemExit(f"Missing config: {cfg_path}") with open(cfg_path, "r") as f: cfg = json.load(f) data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get( "data_prefix", "ukb") full_cov = bool(cfg.get("full_cov", False)) cov_list = None if full_cov else ["bmi", "smoking", "alcohol"] dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list) train_ratio = float(cfg.get("train_ratio", 0.7)) val_ratio = float(cfg.get("val_ratio", 0.15)) seed_split = int(cfg.get("random_seed", 42)) n_total = len(dataset) n_train = int(n_total * train_ratio) n_val = int(n_total * val_ratio) n_test = n_total - n_train - n_val train_ds, val_ds, test_ds = random_split( dataset, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(seed_split), ) if args.split == "train": ds = train_ds elif args.split == "val": ds = val_ds elif args.split == "test": ds = test_ds else: ds = dataset loader = DataLoader( ds, batch_size=int(args.batch_size), shuffle=False, collate_fn=health_collate_fn, num_workers=int(args.num_workers), pin_memory=str(args.device).startswith("cuda"), ) criterion, out_dims = build_criterion_and_out_dims( loss_type=str(cfg["loss_type"]), n_disease=int(dataset.n_disease), bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]), lambda_reg=float(cfg.get("lambda_reg", 0.0)), ) model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg) head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims) device = torch.device(args.device) checkpoint = torch.load(ckpt_path, map_location=device) model.load_state_dict(checkpoint["model_state_dict"], strict=True) head.load_state_dict(checkpoint["head_state_dict"], strict=True) if "criterion_state_dict" in checkpoint: try: criterion.load_state_dict( checkpoint["criterion_state_dict"], strict=False) except Exception: pass model.to(device) head.to(device) criterion.to(device) age_edges = _parse_age_bin_edges(args.age_bin_edges) age_bins = _edges_to_bins(age_edges) eval_cfg = EvalAgeConfig( horizons_years=_parse_floats(args.horizons), age_bins=age_bins, topk_percents=[float(x) for x in args.topk_percent], n_mc=int(args.n_mc), seed=int(args.seed), cause_ids=None, ) df_by_bin, df_agg = evaluate_time_dependent_age_bins( model=model, head=head, criterion=criterion, dataloader=loader, n_disease=int(dataset.n_disease), cfg=eval_cfg, device=device, ) if args.out_prefix is None: out_prefix = os.path.join( args.run_dir, f"age_bin_time_dependent_{args.split}") else: out_prefix = args.out_prefix out_bin = out_prefix + "_by_bin.csv" out_agg = out_prefix + "_agg.csv" df_by_bin.to_csv(out_bin, index=False) df_agg.to_csv(out_agg, index=False) print(f"Wrote: {out_bin}") print(f"Wrote: {out_agg}") if __name__ == "__main__": main()