from __future__ import annotations import argparse import json import math import os from typing import List, Sequence import torch from torch.utils.data import DataLoader, random_split from dataset import HealthDataset, health_collate_fn from evaluation_time_dependent import EvalConfig, evaluate_time_dependent from losses import DiscreteTimeCIFNLLLoss, ExponentialNLLLoss, PiecewiseExponentialCIFNLLLoss from model import DelphiFork, SapDelphi, SimpleHead def _parse_floats(items: Sequence[str]) -> List[float]: out: List[float] = [] for x in items: x = x.strip() if not x: continue out.append(float(x)) return out def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float): if loss_type == "exponential": criterion = ExponentialNLLLoss(lambda_reg=lambda_reg) out_dims = [n_disease] return criterion, out_dims if loss_type == "discrete_time_cif": criterion = DiscreteTimeCIFNLLLoss( bin_edges=bin_edges, lambda_reg=lambda_reg) out_dims = [n_disease + 1, len(bin_edges)] return criterion, out_dims if loss_type == "pwe_cif": pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))] if len(pwe_edges) < 2: raise ValueError( "pwe_cif requires at least 2 finite bin edges (including 0)") if float(pwe_edges[0]) != 0.0: raise ValueError("pwe_cif requires bin_edges[0]==0.0") criterion = PiecewiseExponentialCIFNLLLoss( bin_edges=pwe_edges, lambda_reg=lambda_reg) n_bins = len(pwe_edges) - 1 out_dims = [n_disease, n_bins] return criterion, out_dims raise ValueError(f"Unsupported loss_type: {loss_type}") def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict): if model_type == "delphi_fork": return DelphiFork( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=int(cfg["n_embd"]), n_head=int(cfg["n_head"]), n_layer=int(cfg["n_layer"]), pdrop=float(cfg.get("pdrop", 0.0)), age_encoder_type=str(cfg["age_encoder"]), n_cont=int(dataset.n_cont), n_cate=int(dataset.n_cate), cate_dims=list(dataset.cate_dims), ) if model_type == "sap_delphi": return SapDelphi( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=int(cfg["n_embd"]), n_head=int(cfg["n_head"]), n_layer=int(cfg["n_layer"]), pdrop=float(cfg.get("pdrop", 0.0)), age_encoder_type=str(cfg["age_encoder"]), n_cont=int(dataset.n_cont), n_cate=int(dataset.n_cate), cate_dims=list(dataset.cate_dims), pretrained_weights_path=str( cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")), freeze_embeddings=bool(cfg.get("freeze_embeddings", True)), ) raise ValueError(f"Unsupported model_type: {model_type}") def main() -> None: parser = argparse.ArgumentParser( description="Time-dependent evaluation for DeepHealth") parser.add_argument( "--run_dir", type=str, required=True, help="Training run directory (contains best_model.pt and train_config.json)", ) parser.add_argument("--data_prefix", type=str, default=None, help="Dataset prefix (overrides config if provided)") parser.add_argument("--split", type=str, choices=["train", "val", "test", "all"], default="val") parser.add_argument("--horizons", type=str, nargs="+", default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="One or more horizons (years)") parser.add_argument("--offset_years", type=float, default=0.0, help="Context selection offset (years before follow-up end)") parser.add_argument( "--topk_percent", type=float, nargs="+", default=[1, 5, 10, 20, 50], help="One or more K%% values for recall/precision@K%% (e.g., --topk_percent 1 5 10)", ) parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--num_workers", type=int, default=0, help="Keep 0 on Windows") parser.add_argument("--out_csv", type=str, default=None, help="Optional output CSV path") args = parser.parse_args() ckpt_path = os.path.join(args.run_dir, "best_model.pt") cfg_path = os.path.join(args.run_dir, "train_config.json") if not os.path.exists(ckpt_path): raise SystemExit(f"Missing checkpoint: {ckpt_path}") if not os.path.exists(cfg_path): raise SystemExit(f"Missing config: {cfg_path}") with open(cfg_path, "r") as f: cfg = json.load(f) data_prefix = args.data_prefix if args.data_prefix is not None else cfg.get( "data_prefix", "ukb") # Match training covariate selection. full_cov = bool(cfg.get("full_cov", False)) cov_list = None if full_cov else ["bmi", "smoking", "alcohol"] dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list) # Recreate the same split scheme as train.py train_ratio = float(cfg.get("train_ratio", 0.7)) val_ratio = float(cfg.get("val_ratio", 0.15)) seed = int(cfg.get("random_seed", 42)) n_total = len(dataset) n_train = int(n_total * train_ratio) n_val = int(n_total * val_ratio) n_test = n_total - n_train - n_val train_ds, val_ds, test_ds = random_split( dataset, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(seed), ) if args.split == "train": ds = train_ds elif args.split == "val": ds = val_ds elif args.split == "test": ds = test_ds else: ds = dataset loader = DataLoader( ds, batch_size=int(args.batch_size), shuffle=False, collate_fn=health_collate_fn, num_workers=int(args.num_workers), pin_memory=str(args.device).startswith("cuda"), ) criterion, out_dims = build_criterion_and_out_dims( loss_type=str(cfg["loss_type"]), n_disease=int(dataset.n_disease), bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]), lambda_reg=float(cfg.get("lambda_reg", 0.0)), ) model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg) head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims) device = torch.device(args.device) checkpoint = torch.load(ckpt_path, map_location=device) model.load_state_dict(checkpoint["model_state_dict"], strict=True) head.load_state_dict(checkpoint["head_state_dict"], strict=True) if "criterion_state_dict" in checkpoint: try: criterion.load_state_dict( checkpoint["criterion_state_dict"], strict=False) except Exception: pass model.to(device) head.to(device) criterion.to(device) eval_cfg = EvalConfig( horizons_years=_parse_floats(args.horizons), offset_years=float(args.offset_years), topk_percents=[float(x) for x in args.topk_percent], cause_ids=None, ) df = evaluate_time_dependent( model=model, head=head, criterion=criterion, dataloader=loader, n_disease=int(dataset.n_disease), cfg=eval_cfg, device=device, ) if args.out_csv is None: out_csv = os.path.join( args.run_dir, f"time_dependent_metrics_{args.split}.csv") else: out_csv = args.out_csv df.to_csv(out_csv, index=False) print(f"Wrote: {out_csv}") if __name__ == "__main__": main()