from __future__ import annotations

import argparse
import json
import math
import multiprocessing as mp
import os
from typing import List, Sequence, Tuple

import pandas as pd
import torch
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from evaluation_age_time_dependent import (
    EvalAgeConfig,
    aggregate_age_bin_results,
    evaluate_time_dependent_age_bins,
)
from losses import (
    DiscreteTimeCIFNLLLoss,
    ExponentialNLLLoss,
    PiecewiseExponentialCIFNLLLoss,
)
from model import DelphiFork, SapDelphi, SimpleHead


def _parse_floats(items: Sequence[str]) -> List[float]:
    out: List[float] = []
    for x in items:
        x = x.strip()
        if not x:
            continue
        out.append(float(x))
    return out


def _parse_age_bin_edges(edges: Sequence[str]) -> List[float]:
    vals = _parse_floats(edges)
    if len(vals) < 2:
        raise ValueError("--age_bin_edges must have at least 2 values")
    for i in range(1, len(vals)):
        if not (vals[i] > vals[i - 1]):
            raise ValueError("--age_bin_edges must be strictly increasing")
    return vals


def _edges_to_bins(edges: Sequence[float]) -> List[Tuple[float, float]]:
    return [(float(edges[i]), float(edges[i + 1])) for i in range(len(edges) - 1)]


def _parse_gpus(gpus: str | None) -> List[int]:
    if gpus is None:
        return []
    s = gpus.strip()
    if not s:
        return []
    parts = [p.strip() for p in s.split(",") if p.strip()]
    return [int(p) for p in parts]

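# Worked example of the helpers above (values illustrative):
#   _parse_age_bin_edges(["40", "45", "50"]) -> [40.0, 45.0, 50.0]
#   _edges_to_bins([40.0, 45.0, 50.0])       -> [(40.0, 45.0), (45.0, 50.0)],
#       i.e. the half-open age bins [40, 45) and [45, 50)
#   _parse_gpus("0,1,3") -> [0, 1, 3]; _parse_gpus(None) or "" -> []
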
def _worker_eval_mcs_on_gpu(
    queue: "mp.Queue",
    *,
    run_dir: str,
    split: str,
    data_prefix_override: str | None,
    horizons: List[float],
    age_bins: List[Tuple[float, float]],
    topk_percents: List[float],
    n_mc: int,
    seed: int,
    batch_size: int,
    num_workers: int,
    gpu_id: int,
    mc_indices: List[int],
    out_path: str,
) -> None:
    """Worker process: evaluate a subset of MC indices on a single GPU."""
    try:
        ckpt_path = os.path.join(run_dir, "best_model.pt")
        cfg_path = os.path.join(run_dir, "train_config.json")
        with open(cfg_path, "r") as f:
            cfg = json.load(f)

        data_prefix = (
            data_prefix_override
            if data_prefix_override is not None
            else cfg.get("data_prefix", "ukb")
        )
        full_cov = bool(cfg.get("full_cov", False))
        cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)

        # Recreate the train/val/test split with the same ratios and seed used
        # during training, so each worker evaluates the exact same subjects.
        train_ratio = float(cfg.get("train_ratio", 0.7))
        val_ratio = float(cfg.get("val_ratio", 0.15))
        seed_split = int(cfg.get("random_seed", 42))
        n_total = len(dataset)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        n_test = n_total - n_train - n_val
        train_ds, val_ds, test_ds = random_split(
            dataset,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(seed_split),
        )
        if split == "train":
            ds = train_ds
        elif split == "val":
            ds = val_ds
        elif split == "test":
            ds = test_ds
        else:
            ds = dataset

        loader = DataLoader(
            ds,
            batch_size=int(batch_size),
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=int(num_workers),
            pin_memory=True,
        )

        criterion, out_dims = build_criterion_and_out_dims(
            loss_type=str(cfg["loss_type"]),
            n_disease=int(dataset.n_disease),
            bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        )
        model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
        head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)

        device = torch.device(f"cuda:{int(gpu_id)}")
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"], strict=True)
        head.load_state_dict(checkpoint["head_state_dict"], strict=True)
        if "criterion_state_dict" in checkpoint:
            try:
                criterion.load_state_dict(
                    checkpoint["criterion_state_dict"], strict=False
                )
            except Exception:
                pass
        model.to(device)
        head.to(device)
        criterion.to(device)

        # Run one MC sample per call; mc_offset keeps the MC draws distinct
        # across indices (and hence across workers).
        frames: List[pd.DataFrame] = []
        for mc_idx in mc_indices:
            eval_cfg = EvalAgeConfig(
                horizons_years=horizons,
                age_bins=age_bins,
                topk_percents=topk_percents,
                n_mc=1,
                seed=int(seed),
                cause_ids=None,
            )
            df_by_bin, _df_agg_unused = evaluate_time_dependent_age_bins(
                model=model,
                head=head,
                criterion=criterion,
                dataloader=loader,
                n_disease=int(dataset.n_disease),
                cfg=eval_cfg,
                device=device,
                mc_offset=int(mc_idx),
            )
            frames.append(df_by_bin)

        df_all = pd.concat(frames, ignore_index=True) if len(frames) else pd.DataFrame()
        df_all.to_csv(out_path, index=False)
        queue.put({"ok": True, "out_path": out_path})
    except Exception as e:
        queue.put({"ok": False, "error": repr(e)})


def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lambda_reg: float):
    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(lambda_reg=lambda_reg)
        out_dims = [n_disease]
        return criterion, out_dims
    if loss_type == "discrete_time_cif":
        criterion = DiscreteTimeCIFNLLLoss(bin_edges=bin_edges, lambda_reg=lambda_reg)
        out_dims = [n_disease + 1, len(bin_edges)]
        return criterion, out_dims
    if loss_type == "pwe_cif":
        pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))]
        if len(pwe_edges) < 2:
            raise ValueError("pwe_cif requires at least 2 finite bin edges (including 0)")
        if float(pwe_edges[0]) != 0.0:
            raise ValueError("pwe_cif requires bin_edges[0]==0.0")
        criterion = PiecewiseExponentialCIFNLLLoss(bin_edges=pwe_edges, lambda_reg=lambda_reg)
        n_bins = len(pwe_edges) - 1
        out_dims = [n_disease, n_bins]
        return criterion, out_dims
    raise ValueError(f"Unsupported loss_type: {loss_type}")

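# For example, with n_disease=1000 (illustrative) and the default
# bin_edges=[0.0, 1.0, inf], build_criterion_and_out_dims returns
# out_dims=[1000] for "exponential", [1001, 3] for "discrete_time_cif", and,
# after dropping the non-finite edge, [1000, 1] for "pwe_cif"
# (finite edges [0.0, 1.0] -> 1 piecewise-constant hazard bin).
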
def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
    if model_type == "delphi_fork":
        return DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
        )
    if model_type == "sap_delphi":
        return SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg["age_encoder"]),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
            pretrained_weights_path=str(
                cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")
            ),
            freeze_embeddings=bool(cfg.get("freeze_embeddings", True)),
        )
    raise ValueError(f"Unsupported model_type: {model_type}")


def main() -> None:
    parser = argparse.ArgumentParser(
        description=(
            "Delphi-2M-style age-bin time-dependent evaluation (writes per-bin and "
            "aggregated CSVs; aggregated includes agg_type={macro,weighted})"
        )
    )
    parser.add_argument(
        "--run_dir",
        type=str,
        required=True,
        help="Training run directory (contains best_model.pt and train_config.json)",
    )
    parser.add_argument("--data_prefix", type=str, default=None)
    parser.add_argument(
        "--split", type=str, choices=["train", "val", "test", "all"], default="val"
    )
    parser.add_argument(
        "--horizons",
        type=str,
        nargs="+",
        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
    )
    parser.add_argument(
        "--age_bin_edges",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "80"],
        help="Age bin edges in years (e.g., --age_bin_edges 40 45 50 ...). "
        "Bins are [edge[i], edge[i+1]).",
    )
    parser.add_argument(
        "--topk_percent",
        type=float,
        nargs="+",
        default=[1, 5, 10, 20, 50],
        help="One or more K%% values for recall/precision@K%%",
    )
    parser.add_argument("--n_mc", type=int, default=5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help=(
            "Comma-separated GPU ids to parallelize MC runs (one worker per GPU; "
            "one MC per GPU at a time). Example: --gpus 0,1,3"
        ),
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument(
        "--out_prefix", type=str, default=None, help="Output prefix for CSVs"
    )
    args = parser.parse_args()

    ckpt_path = os.path.join(args.run_dir, "best_model.pt")
    cfg_path = os.path.join(args.run_dir, "train_config.json")
    if not os.path.exists(ckpt_path):
        raise SystemExit(f"Missing checkpoint: {ckpt_path}")
    if not os.path.exists(cfg_path):
        raise SystemExit(f"Missing config: {cfg_path}")
    with open(cfg_path, "r") as f:
        cfg = json.load(f)

    data_prefix = (
        args.data_prefix if args.data_prefix is not None else cfg.get("data_prefix", "ukb")
    )
    full_cov = bool(cfg.get("full_cov", False))
    cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
    dataset = HealthDataset(data_prefix=data_prefix, covariate_list=cov_list)

    # Recreate the train/val/test split with the ratios and seed from training.
    train_ratio = float(cfg.get("train_ratio", 0.7))
    val_ratio = float(cfg.get("val_ratio", 0.15))
    seed_split = int(cfg.get("random_seed", 42))
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed_split),
    )
    if args.split == "train":
        ds = train_ds
    elif args.split == "val":
        ds = val_ds
    elif args.split == "test":
        ds = test_ds
    else:
        ds = dataset

    loader = DataLoader(
        ds,
        batch_size=int(args.batch_size),
        shuffle=False,
        collate_fn=health_collate_fn,
        num_workers=int(args.num_workers),
        pin_memory=str(args.device).startswith("cuda"),
    )

    criterion, out_dims = build_criterion_and_out_dims(
        loss_type=str(cfg["loss_type"]),
        n_disease=int(dataset.n_disease),
        bin_edges=cfg.get("bin_edges", [0.0, 1.0, float("inf")]),
        lambda_reg=float(cfg.get("lambda_reg", 0.0)),
    )
    model = build_model(str(cfg["model_type"]), dataset=dataset, cfg=cfg)
    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims)

    device = torch.device(args.device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    head.load_state_dict(checkpoint["head_state_dict"], strict=True)
    if "criterion_state_dict" in checkpoint:
        try:
            criterion.load_state_dict(checkpoint["criterion_state_dict"], strict=False)
        except Exception:
            pass
    model.to(device)
    head.to(device)
    criterion.to(device)

    age_edges = _parse_age_bin_edges(args.age_bin_edges)
    age_bins = _edges_to_bins(age_edges)
    eval_cfg = EvalAgeConfig(
        horizons_years=_parse_floats(args.horizons),
        age_bins=age_bins,
        topk_percents=[float(x) for x in args.topk_percent],
        n_mc=int(args.n_mc),
        seed=int(args.seed),
        cause_ids=None,
    )

    if args.out_prefix is None:
        out_prefix = os.path.join(args.run_dir, f"age_bin_time_dependent_{args.split}")
    else:
        out_prefix = args.out_prefix
    out_bin = out_prefix + "_by_bin.csv"
    out_agg = out_prefix + "_agg.csv"

    gpus = _parse_gpus(args.gpus)
    if len(gpus) <= 1:
        # Single-process path: all MC samples run sequentially on args.device.
        df_by_bin, df_agg = evaluate_time_dependent_age_bins(
            model=model,
            head=head,
            criterion=criterion,
            dataloader=loader,
            n_disease=int(dataset.n_disease),
            cfg=eval_cfg,
            device=device,
        )
        df_by_bin.to_csv(out_bin, index=False)
        df_agg.to_csv(out_agg, index=False)
        print(f"Wrote: {out_bin}")
        print(f"Wrote: {out_agg}")
        return

    if not torch.cuda.is_available():
        raise SystemExit("--gpus was provided but CUDA is not available")

    # Multi-GPU path: run MC indices in parallel across GPUs (one worker per GPU).
    mc_indices_all = list(range(int(args.n_mc)))
    per_gpu: List[Tuple[int, List[int]]] = []
    for pos, gpu_id in enumerate(gpus):
        assigned = [i for i in mc_indices_all if (i % len(gpus)) == pos]
        if assigned:
            per_gpu.append((int(gpu_id), assigned))
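    # Worked example of the round-robin assignment above: --gpus 0,1,3 with
    # --n_mc 5 gives GPU 0 -> MC indices [0, 3], GPU 1 -> [1, 4], GPU 3 -> [2].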
    ctx = mp.get_context("spawn")
    queue: "mp.Queue" = ctx.Queue()
    procs: List[mp.Process] = []
    tmp_paths: List[str] = []
    for gpu_id, mc_idxs in per_gpu:
        tmp_path = f"{out_prefix}__tmp_gpu{gpu_id}.csv"
        tmp_paths.append(tmp_path)
        p = ctx.Process(
            target=_worker_eval_mcs_on_gpu,
            kwargs=dict(
                queue=queue,
                run_dir=str(args.run_dir),
                split=str(args.split),
                data_prefix_override=(
                    str(args.data_prefix) if args.data_prefix is not None else None
                ),
                horizons=_parse_floats(args.horizons),
                age_bins=age_bins,
                topk_percents=[float(x) for x in args.topk_percent],
                n_mc=int(args.n_mc),
                seed=int(args.seed),
                batch_size=int(args.batch_size),
                num_workers=int(args.num_workers),
                gpu_id=int(gpu_id),
                mc_indices=mc_idxs,
                out_path=tmp_path,
            ),
        )
        p.start()
        procs.append(p)

    # Drain the result queue before joining so workers never block on a full queue.
    results = [queue.get() for _ in range(len(procs))]
    for p in procs:
        p.join()
    for r in results:
        if not r.get("ok", False):
            raise SystemExit(f"Worker failed: {r.get('error')}")

    frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
    df_by_bin = pd.concat(frames, ignore_index=True) if len(frames) else pd.DataFrame()
    df_agg = aggregate_age_bin_results(df_by_bin)
    df_by_bin.to_csv(out_bin, index=False)
    df_agg.to_csv(out_agg, index=False)

    # Best-effort cleanup of per-GPU temporary CSVs.
    for p in tmp_paths:
        try:
            if os.path.exists(p):
                os.remove(p)
        except Exception:
            pass

    print(f"Wrote: {out_bin}")
    print(f"Wrote: {out_agg}")


if __name__ == "__main__":
    main()
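
# Example invocations (the script filename and run directory are illustrative
# placeholders; all flags are defined in main() above):
#
#   python age_bin_eval.py --run_dir runs/exp1 --split test
#   python age_bin_eval.py --run_dir runs/exp1 --split val --n_mc 10 \
#       --gpus 0,1,3 --horizons 1.0 5.0 --age_bin_edges 40 50 60 70 80 \
#       --topk_percent 1 5 10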