import argparse
import csv
import json
import math
import os
import random
import statistics
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from losses import DiscreteTimeCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead

# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
# Discrete-time bin edges in years; the first edge is 0 and the last is +inf.
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
# Evaluation horizons in years; each coincides with a finite edge above, which
# the discrete-time CIF conversion requires.
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
DAYS_PER_YEAR = 365.25


# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
    """One entry of the models JSON: a trained checkpoint to evaluate."""

    name: str
    model_type: str  # delphi_fork | sap_delphi
    loss_type: str  # exponential | discrete_time_cif
    full_cov: bool
    checkpoint_path: str


# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
    """Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
    """Parse a bool from a bool or a common textual spelling (true/false, 1/0, yes/no, y/n).

    Raises:
        ValueError: if ``x`` is not a recognised spelling.
    """
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    if s in {"true", "1", "yes", "y"}:
        return True
    if s in {"false", "0", "no", "n"}:
        return False
    raise ValueError(f"Cannot parse boolean: {x!r}")


def load_models_json(path: str) -> List[ModelSpec]:
    """Load the list of models to evaluate from a JSON file.

    The file must hold a JSON list whose entries provide the keys ``name``,
    ``model_type``, ``loss_type``, ``full_cov`` and ``checkpoint_path``.

    Raises:
        ValueError: if the top-level value is not a list, or ``full_cov`` is
            not a recognisable boolean.
        KeyError: if an entry lacks a required key.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")
    return [
        ModelSpec(
            name=str(row["name"]),
            model_type=str(row["model_type"]),
            loss_type=str(row["loss_type"]),
            full_cov=_parse_bool(row["full_cov"]),
            checkpoint_path=str(row["checkpoint_path"]),
        )
        for row in data
    ]


def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    """Read ``train_config.json`` from the directory that contains ``checkpoint_path``."""
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    return cfg


def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    """Return the requested split of ``dataset``.

    The train/val/test partition is recreated with ``random_split`` and a
    generator seeded with ``seed``. It therefore matches any split made with
    the same ratios and seed (presumably the training-time split; confirm
    against the trainer). ``split="all"`` returns the full dataset.

    Raises:
        ValueError: for an unknown ``split``, or ratios that produce a negative
            split size (e.g. ``train_ratio + val_ratio > 1``).
    """
    if split not in ("train", "val", "test", "all"):
        raise ValueError("split must be one of: train, val, test, all")
    if split == "all":
        # No partition needed; random_split uses its own generator, so skipping
        # it has no effect on global RNG state.
        return dataset

    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    if n_train < 0 or n_val < 0 or n_test < 0:
        # random_split only checks that lengths sum to len(dataset); negative
        # lengths would silently produce overlapping/garbage subsets.
        raise ValueError(
            f"Invalid split ratios (train_ratio={train_ratio}, val_ratio={val_ratio}): "
            f"sizes would be train={n_train}, val={n_val}, test={n_test}"
        )

    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    return {"train": train_ds, "val": val_ds, "test": test_ds}[split]


# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select per-sample prediction context index.

    IMPORTANT SEMANTICS:
      - The last observed token time is treated as the FOLLOW-UP END time.
      - We pick the last valid token with time <= (followup_end_time - offset).
      - We do NOT interpret followup_end_time as an event time.

    Token id 0 is padding. The index arithmetic (``lengths - 1`` and
    ``eligible_counts - 1``) assumes the valid tokens of each row form a
    contiguous, time-sorted prefix (right padding).

    Args:
        event_seq: (B, L) token ids.
        time_seq: (B, L) token times in days.
        offset_years: anti-leakage gap between context and follow-up end, in years.

    Returns:
        keep_mask: (B,) bool, which samples have a valid context
        t_ctx: (B,) long, index into sequence (0 where keep_mask is False)
        t_ctx_time: (B,) time (days) at context
    """
    valid = event_seq != 0
    lengths = valid.sum(dim=1)
    last_idx = torch.clamp(lengths - 1, min=0)
    b = torch.arange(event_seq.size(0), device=event_seq.device)
    followup_end_time = time_seq[b, last_idx]
    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)

    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    eligible_counts = eligible.sum(dim=1)
    keep = eligible_counts > 0

    # With a sorted prefix, the last eligible token sits at index count-1.
    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[b, t_ctx]
    return keep, t_ctx, t_ctx_time


def _future_disease_mask(event_seq: torch.Tensor, t_ctx: torch.Tensor) -> torch.Tensor:
    """(B, L) bool mask of disease tokens (id >= 2) at indices strictly after each row's context."""
    idxs = torch.arange(event_seq.size(1), device=event_seq.device).unsqueeze(0)
    # Index (not time) comparison: same-day events after the context token are
    # included, the context token itself is excluded. ids >= 2 also excludes padding 0.
    return (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2)


def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the first disease event after the context token.

    Relies on time-sorted sequences: the first future disease token by index
    is the next event in time.

    Returns:
        dt_years: (B,) float, time to next disease in years; +inf if none
        cause: (B,) long, disease id in [0,K) for next event; -1 if none
    """
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]

    future = _future_disease_mask(event_seq, t_ctx)
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    # Sentinel L marks "no future disease"; min over the row gives the first one.
    idx_min = torch.where(future, idxs, torch.full_like(idxs, L)).min(dim=1).values
    has = idx_min < L
    t_next = torch.where(has, idx_min, torch.zeros_like(idx_min))

    dt_years = (time_seq[b, t_next] - t0) / DAYS_PER_YEAR
    dt_years = torch.where(has, dt_years, torch.full_like(dt_years, float("inf")))

    cause = (event_seq[b, t_next] - 2).to(torch.long)
    cause = torch.where(has, cause, torch.full_like(cause, -1))
    return dt_years, cause


def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """(B, n_disease) bool labels: disease k occurs within ``tau_years`` after the context (any occurrence).

    The window is [t_ctx_time, t_ctx_time + tau] in days. Same-day events
    after the context token index are included.
    """
    B = event_seq.size(0)
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)

    in_window = (
        _future_disease_mask(event_seq, t_ctx)
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
    )
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    if not in_window.any():
        return y
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y[b_idx, disease_ids] = True
    return y


def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """(B, n_disease) bool labels: disease k occurs ANYTIME after the prediction context.

    This is Delphi2M-compatible for Task A case/control definition.
    Same-day events are included as long as they occur after the context token index.
    """
    B = event_seq.size(0)
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    future = _future_disease_mask(event_seq, t_ctx)
    if not future.any():
        return y
    b_idx, t_idx = future.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y[b_idx, disease_ids] = True
    return y


def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """(B, len(cause_ids)) bool labels: does cause ``cause_ids[j]`` occur within tau after context?

    Column j corresponds to ``cause_ids[j]``. Duplicate ids in ``cause_ids``
    each receive the same label.
    """
    all_causes = multi_hot_ever_within_horizon(
        event_seq=event_seq,
        time_seq=time_seq,
        t_ctx=t_ctx,
        tau_years=tau_years,
        n_disease=int(n_disease),
    )
    index = cause_ids.to(device=all_causes.device, dtype=torch.long).reshape(-1)
    return all_causes.index_select(dim=1, index=index)


# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Convert exponential (constant-hazard) cause-specific logits -> CIFs at taus.

    With hazards lambda_k = softplus(logit_k) + eps and Lambda = sum_k lambda_k:

        CIF_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau))
        S(tau)     = exp(-Lambda * tau)

    The hazard rate must use the same time unit as ``taus`` (presumably per
    year, since horizons elsewhere are in years; confirm against the loss).

    Args:
        logits: (B, K)
        taus: H horizons
        eps: floor added to every hazard
        return_survival: also return S(tau)

    Returns:
        (B, K, H) CIFs, or (cif, survival) with survival (B, H) if return_survival.
    """
    hazards = F.softplus(logits) + eps
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)
    taus_t = torch.tensor(list(taus), device=logits.device,
                          dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B,1,1)

    one_minus_surv = 1.0 - torch.exp(-total_h * taus_t)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B,K,H)
    # Only reachable with eps == 0 (0/0 -> NaN); force such rows to 0.
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))
    if not return_survival:
        return cif

    survival = torch.exp(-total_h * taus_t).squeeze(1)  # (B,H)
    # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival


def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Convert discrete-time CIF logits -> CIFs at taus.

    In every bin slot a softmax is taken over the K+1 outcome channels; the
    last channel is the complement ("no event in this bin"). Slot 0 is
    ignored; slots 1..n_bins correspond to the intervals ending at the finite
    edges of ``bin_edges[1:]``. With p_k[u] the cause-k probability and
    p_comp[u] the complement probability in bin u:

        S_prev[u]      = prod_{v<u}  p_comp[v]
        CIF_k(edge_u)  = sum_{v<=u} S_prev[v] * p_k[v]
        S(edge_u)      = prod_{v<=u} p_comp[v]

    Args:
        logits: (B, K+1, len(bin_edges))
        bin_edges: edges including 0 first and (typically) +inf last
        taus: horizons; each must lie within 1e-6 of a finite edge in bin_edges[1:]
        return_survival: also return survival at taus

    Returns:
        (B, K, H) CIFs, or (cif, survival) with survival (B, H) if return_survival.

    Raises:
        ValueError: on wrong logits rank/size, a non-finite tau, or a tau not
            matching any finite bin edge.
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")

    B, K_plus_1, n_bins_plus_1 = logits.shape
    K = K_plus_1 - 1
    edges = [float(x) for x in bin_edges]
    # Drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf.
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)
    if n_bins_plus_1 != len(edges):
        raise ValueError("logits last dim must match len(bin_edges)")

    probs = torch.softmax(logits, dim=1)  # (B, K+1, n_bins+1)
    # Use bins 1..n_bins (ignore bin 0 and the +inf slot).
    hazards = probs[:, :K, 1: 1 + n_bins]  # (B,K,n_bins)
    p_comp = probs[:, K, 1: 1 + n_bins]  # (B,n_bins)

    ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
    cum = torch.cumprod(p_comp, dim=1)  # cum[u] = prod_{v<=u} p_comp[v]
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # (B, n_bins)
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * hazards, dim=2)  # (B,K,n_bins)

    # Robust mapping from tau -> edge index (floating-point safe): taus may differ
    # slightly from the edges due to parsing/serialization.
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        j = int(np.argmin(diffs))
        if diffs[j] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
            )
        tau_to_idx.append(j)

    idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,K,H)
    if not return_survival:
        return cif

    survival = cum.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival


# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on CIF arrays.

    Checks: finiteness, range [0, 1], monotonicity in horizon (evaluated in
    increasing-horizon order), sum over causes <= 1 and, if ``survival`` is
    given, conservation 1 - S(tau) == sum_k CIF_k(tau).

    Args:
        cause_cif: (N, K, H)
        horizons: length H
        tol: tolerance for inequalities
        name: model name for messages
        strict: if True, raise ValueError on first failure
        survival: optional (N, H) survival values at the same horizons

    Returns:
        (integrity_ok, notes). A missing ``survival`` is noted but does not
        make integrity_ok False.
    """
    notes: List[str] = []
    model_tag = f"[{name}] " if name else ""

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    if cif.size == 0:
        # np.nanmin/np.nanmax below would raise on a zero-size array.
        _fail(f"integrity: cause_cif is empty (shape={cif.shape})")
        return False, notes

    N, _K, H = cif.shape
    if H != len(horizons):
        _fail(
            f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")

    # (5) Finite
    if not np.isfinite(cif).all():
        _fail("integrity: non-finite values (NaN/Inf) in cause_cif")

    # (1) Range
    cmin = float(np.nanmin(cif))
    cmax = float(np.nanmax(cif))
    if cmin < -tol:
        _fail(f"integrity: range min={cmin} < -tol={-tol}")
    if cmax > 1.0 + tol:
        _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")

    # (2) Monotonicity in horizons (per n,k). Order the horizon axis by horizon
    # value so that unsorted eval horizons are not reported as violations.
    if H == len(horizons):
        order = np.argsort(np.asarray(horizons, dtype=float), kind="stable")
        cif_by_tau = cif[:, :, order]
    else:
        cif_by_tau = cif
    diffs = np.diff(cif_by_tau, axis=2)
    if diffs.size > 0 and np.nanmin(diffs) < -tol:
        _fail("integrity: monotonicity violated (found negative diff along horizons)")

    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N,H)
    mass_max = float(np.nanmax(mass))
    if mass_max > 1.0 + tol:
        _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")

    # (4) Conservation with survival, if provided
    if survival is None:
        warn = "integrity: survival not provided; skipping conservation check"
        if strict:
            # still skip (requested behavior), but keep message for context
            notes.append(warn)
        else:
            print("WARNING:", model_tag + warn)
            notes.append(warn)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(
                f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        else:
            err = np.abs((1.0 - s) - mass)
            # Discrete-time should be very tight; exponential may accumulate
            # slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if float(np.nanmax(err)) > tol_cons:
                _fail(
                    f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})")

    # The "survival not provided" note is informational, not a failure.
    ok = all(n.endswith("skipping conservation check") for n in notes)
    return ok, notes


# ============================================================
# Metrics
# ============================================================
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Return 1-based midranks of ``x``; tied values share the average of their ranks."""
    x = np.asarray(x, dtype=float)
    order = np.argsort(x)
    z = x[order]
    n = x.shape[0]
    t = np.zeros(n, dtype=float)
    i = 0
    while i < n:
        j = i
        while j < n and z[j] == z[i]:
            j += 1
        # Positions i..j-1 (0-based) are tied; their mean 1-based rank.
        t[i:j] = 0.5 * (i + j - 1) + 1.0
        i = j
    out = np.empty(n, dtype=float)
    out[order] = t
    return out


def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for computing AUC covariance.

    Args:
        predictions_sorted_transposed: shape (n_classifiers, n_examples) with
            positive examples first.
        label_1_count: number of positive examples (m).

    Returns:
        (aucs of shape (n_classifiers,), covariance of shape
        (n_classifiers, n_classifiers)); NaN-filled if either class is empty.
    """
    preds = np.asarray(predictions_sorted_transposed, dtype=float)
    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        return np.array([float("nan")]), np.array([[float("nan")]])

    pos = preds[:, :m]
    neg = preds[:, m:]
    tx = np.array([compute_midrank(x) for x in pos])
    ty = np.array([compute_midrank(x) for x in neg])
    tz = np.array([compute_midrank(x) for x in preds])

    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
    v01 = (tz[:, :m] - tx) / n
    v10 = 1.0 - (tz[:, m:] - ty) / m

    if v01.shape[0] > 1:
        sx = np.cov(v01)
        sy = np.cov(v10)
    else:
        # Single-classifier case: compute row-wise variance (do not flatten).
        # ddof=1 matches np.cov; with a single positive/negative this is NaN.
        var_v01 = float(np.var(v01, axis=1, ddof=1)[0])
        var_v10 = float(np.var(v10, axis=1, ddof=1)[0])
        sx = np.array([[var_v01]])
        sy = np.array([[var_v10]])

    delong_cov = sx / m + sy / n
    return aucs, delong_cov


def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    """Return (AUC, DeLong variance) for binary labels; (NaN, NaN) if a class is absent.

    Raises:
        ValueError: if inputs are not 1D arrays of equal length.
    """
    y = np.asarray(ground_truth, dtype=int)
    p = np.asarray(predictions, dtype=float)
    if y.ndim != 1 or p.ndim != 1 or y.shape[0] != p.shape[0]:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan"), float("nan")

    order = np.argsort(-y)  # positives first
    aucs, cov = fastDeLong(p[order][np.newaxis, :], m)
    return float(aucs[0]), float(cov[0, 0])


def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return (auc, ci_low, ci_high) using DeLong variance and a normal CI.

    ``alpha`` is the confidence level (0.95 -> 95% two-sided interval); bounds
    are clipped to [0, 1]. CI bounds are NaN when the variance is not positive.
    """
    auc, var = calc_auc_variance(ground_truth, predictions)
    if not np.isfinite(var) or var <= 0:
        print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
        return float(auc), float("nan"), float("nan")
    sd = math.sqrt(var)
    z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0)
    lo = max(0.0, auc - z * sd)
    hi = min(1.0, auc + z * sd)
    return float(auc), float(lo), float(hi)


def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC via Mann–Whitney U statistic (ties handled by midranks).

    Returns NaN for degenerate labels.

    Raises:
        ValueError: if inputs are not 1D arrays of equal length.
    """
    y = np.asarray(y_true, dtype=int)
    s = np.asarray(y_score, dtype=float)
    if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan")
    ranks = compute_midrank(s)
    sum_pos = float(np.sum(ranks[y == 1]))
    return float((sum_pos - m * (m + 1) / 2.0) / (m * n))


def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Percentile bootstrap CI for ROC AUC.

    Resamples containing a single class are skipped. Fewer than 10 valid
    resamples yields NaN bounds. ``alpha`` is the confidence level.

    Returns:
        (full-sample AUC, ci_low, ci_high)
    """
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    if n == 0 or np.all(labels == labels[0]):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    aucs: List[float] = []
    for _ in range(int(n_bootstrap)):
        idx = rng.integers(0, n, size=n)
        yb = labels[idx]
        if np.all(yb == yb[0]):
            continue
        auc = roc_auc_rank(yb, scores[idx])
        if np.isfinite(auc):
            aucs.append(float(auc))

    if len(aucs) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), float("nan"), float("nan")

    lo_q = (1.0 - float(alpha)) / 2.0
    hi_q = 1.0 - lo_q
    lo = float(np.quantile(aucs, lo_q))
    hi = float(np.quantile(aucs, hi_q))
    return float(auc_full), lo, hi


def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    """Mean squared error between predicted probabilities ``p`` and binary outcomes ``y``."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.mean((p - y) ** 2))


def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    """Quantile-binned calibration summary.

    Bin edges are the empirical quantiles of ``p`` (deciles for the default
    ``n_bins=10``). The outermost edges are widened to -inf/+inf so every
    prediction falls into a bin. Bins left empty (e.g. because tied quantiles
    make an interval degenerate) are skipped.

    Returns:
        dict with
          - "bins": list of {"bin", "p_mean", "y_mean", "n"} for non-empty bins
          - "ece": sum over bins of (fraction of samples in bin) * |p_mean - y_mean|
          - "ici": UNWEIGHTED mean of |p_mean - y_mean| over non-empty bins.
            NOTE: a binned approximation, not the loess-smoothed Integrated
            Calibration Index.
        Empty input gives NaN metrics and no bins.
    """
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    if p.size == 0:
        return {"bins": [], "ece": float("nan"), "ici": float("nan")}

    edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
    # Open the outer edges so the (p > lo) & (p <= hi) test covers every value,
    # including the minimum prediction.
    edges[0] = -np.inf
    edges[-1] = np.inf

    bins = []
    ece = 0.0
    ici_accum = 0.0
    for i in range(n_bins):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        if not np.any(mask):
            continue
        p_mean = float(np.mean(p[mask]))
        y_mean = float(np.mean(y[mask]))
        frac = float(np.mean(mask))
        bins.append({"bin": i, "p_mean": p_mean, "y_mean": y_mean, "n": int(mask.sum())})
        ece += frac * abs(p_mean - y_mean)
        ici_accum += abs(p_mean - y_mean)

    ici = ici_accum / max(len(bins), 1)
    return {"bins": bins, "ece": float(ece), "ici": float(ici)}


def count_ever_after_context_anytime(
    loader: DataLoader,
    offset_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count per-person ever-occurrence for each disease after the prediction context.

    Each batch is expected to unpack as (event_seq, time_seq, cont_feats,
    cate_feats, sexes); only the first two are used. Samples without a valid
    context (see select_context_indices) are skipped.

    Returns:
        counts: (n_disease,) counts[k] = individuals with disease k at least once after context
        n_total_eval: number of individuals with a valid context
    """
    counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_total_eval = 0
    for batch in loader:
        event_seq, time_seq, _cont_feats, _cate_feats, _sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)

        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        n_total_eval += int(keep.sum().item())

        # Multi-hot rows are per person, so a column sum counts people, not events.
        y_ever = multi_hot_ever_after_context_anytime(
            event_seq=event_seq[keep],
            t_ctx=t_ctx[keep],
            n_disease=int(n_disease),
        )
        counts += y_ever.sum(dim=0)

    return counts.detach().cpu().numpy(), int(n_total_eval)


# ============================================================
# Evaluation core
# ============================================================
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    """Build backbone and head from a training config (no weights are loaded here).

    Args:
        cfg: training config; reads model_type, loss_type, n_embd, n_head,
            n_layer and optional pdrop, age_encoder, bin_edges, and (for
            sap_delphi) pretrained_emb_path / pretrained_emd_path.
        dataset: supplies n_disease, n_cont, n_cate, cate_dims.
        device: target device.
        checkpoint_path: used only in the missing-embedding warning.

    Returns:
        (backbone, head, loss_type, bin_edges)

    Raises:
        ValueError: for an unsupported loss_type or model_type.
    """
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])
    # Copy so callers can never mutate the module-level default.
    bin_edges = list(cfg.get("bin_edges", DEFAULT_BIN_EDGES))

    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")

    if model_type not in ("delphi_fork", "sap_delphi"):
        raise ValueError(f"Unsupported model_type: {model_type}")

    backbone_kwargs: Dict[str, Any] = dict(
        n_disease=dataset.n_disease,
        n_tech_tokens=2,
        n_embd=int(cfg["n_embd"]),
        n_head=int(cfg["n_head"]),
        n_layer=int(cfg["n_layer"]),
        pdrop=float(cfg.get("pdrop", 0.0)),
        age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
        n_cont=dataset.n_cont,
        n_cate=dataset.n_cate,
        cate_dims=dataset.cate_dims,
    )

    if model_type == "delphi_fork":
        backbone = DelphiFork(**backbone_kwargs).to(device)
    else:
        # Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in ("", None):
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in ("", None):
            run_dir = os.path.dirname(os.path.abspath(
                checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            **backbone_kwargs,
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    return backbone, head, loss_type, bin_edges


@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
           np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the model at each sample's context token and build predictions and labels.

    Returns (in this order):
        allcause_risk: (N, H) sum over all causes of CIF
        cause_cif: (N, topK, H)
        cif_full: (N, K, H)
        survival: (N, H)
        sex: (N,)
        y_allcause_tau: (N, H) next disease event within each horizon
        y_cause_ever_anytime: (N, topK)
        y_cause_within_tau: (N, topK, H)
        y_cause_within_tau_max: (N, topK) within-horizon labels at max(eval_horizons)

    NOTE:
      - y_cause_ever_anytime is the Delphi2M-compatible case/control label.
      - y_cause_within_tau_* are within-horizon labels (legacy/secondary AUC).
      - Labels do not account for censoring: a person followed for less than
        tau after context with no event is labelled negative.

    Raises:
        ValueError: if eval_horizons is empty or loss_type is unsupported.
        RuntimeError: if no sample has a valid context.
    """
    if len(eval_horizons) == 0:
        raise ValueError("eval_horizons must contain at least one horizon")

    backbone.eval()
    head.eval()

    horizons = [float(t) for t in eval_horizons]
    # First index of the largest horizon: the within-horizon labels at this
    # index are exactly the labels for tau = max(eval_horizons).
    tau_max_idx = max(range(len(horizons)), key=horizons.__getitem__)
    top_cause_ids_t = torch.tensor(top_cause_ids, dtype=torch.long, device=device)
    # Pre-create horizons tensor once per model to vectorize label comparisons.
    eval_horizons_t = torch.tensor(horizons, device=device, dtype=torch.float32).view(1, -1)

    keys = (
        "allcause", "cause_cif", "cif_full", "survival", "sex",
        "y_all", "y_ever_any", "y_within", "y_within_tau_max",
    )
    # Accumulate per-batch CPU arrays, concatenated once at the end.
    chunks: Dict[str, List[np.ndarray]] = {k: [] for k in keys}

    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)

        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue

        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats.to(device)[keep]
        cate_feats = cate_feats.to(device)[keep]
        sexes_k = sexes.to(device)[keep]
        t_ctx = t_ctx[keep]

        h = backbone(event_seq, time_seq, sexes_k, cont_feats, cate_feats)  # (B,L,D)
        b = torch.arange(h.size(0), device=device)
        logits = head(h[b, t_ctx])  # head applied to the context-token state

        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                logits, bin_edges, horizons, return_survival=True)  # (B,K,H), (B,H)
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")
        n_disease = int(cif_full.size(1))

        # All-cause outcome: first disease event after context within tau.
        dt_next, _cause_next = next_event_after_context(event_seq, time_seq, t_ctx)
        y_all = (dt_next.view(-1, 1) <= eval_horizons_t).to(torch.float32)

        # Delphi2M-compatible ever label (does not depend on horizon).
        y_ever_any_top = multi_hot_ever_after_context_anytime(
            event_seq=event_seq, t_ctx=t_ctx, n_disease=n_disease,
        ).index_select(dim=1, index=top_cause_ids_t).to(torch.float32)

        # Within-horizon labels for cause-specific CIF quality + legacy AUC.
        y_within_top = torch.stack(
            [
                multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=tau,
                    cause_ids=top_cause_ids_t,
                    n_disease=n_disease,
                ).to(torch.float32)
                for tau in horizons
            ],
            dim=2,
        )  # (B,topK,H)

        batch_out = {
            "allcause": cif_full.sum(dim=1),  # (B,H)
            "cause_cif": cif_full.index_select(dim=1, index=top_cause_ids_t),  # (B,topK,H)
            "cif_full": cif_full,
            "survival": survival,
            "sex": sexes_k,
            "y_all": y_all,
            "y_ever_any": y_ever_any_top,
            "y_within": y_within_top,
            "y_within_tau_max": y_within_top[:, :, tau_max_idx].contiguous(),
        }
        for k, v in batch_out.items():
            chunks[k].append(v.detach().cpu().numpy())

    if not chunks["allcause"]:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")

    return tuple(np.concatenate(chunks[k], axis=0) for k in keys)


def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    """Return up to ``top_k`` column indices of ``y_ever``, most positives first.

    Causes with zero positives are excluded. Ties are broken by lower index
    (stable sort) so the selection is reproducible. ``top_k <= 0`` yields an
    empty array.
    """
    counts = np.asarray(y_ever).sum(axis=0)
    order = np.argsort(-counts, kind="stable")
    order = order[counts[order] > 0]
    return order[: max(int(top_k), 0)]


def _auc_with_ci(
    y: np.ndarray,
    p: np.ndarray,
    auc_ci_method: str,
    bootstrap_n: int,
) -> Tuple[float, float, float]:
    """(auc, ci_low, ci_high) per method: "none" -> all NaN, "bootstrap" -> percentile bootstrap, else DeLong."""
    if auc_ci_method == "none":
        nan = float("nan")
        return nan, nan, nan
    if auc_ci_method == "bootstrap":
        return bootstrap_auc_ci(p, y, n_bootstrap=bootstrap_n, alpha=0.95)
    return delong_ci(y, p, alpha=0.95)


def _metric_row(
    model_name: str,
    metric_name: str,
    horizon: float,
    cause: Any,
    value: Any,
    ci_low: Any = "",
    ci_high: Any = "",
) -> Dict[str, Any]:
    """Build one results-CSV row (columns as in write_results_csv)."""
    return {
        "model_name": model_name,
        "metric_name": metric_name,
        "horizon": float(horizon),
        "cause": cause,
        "value": value,
        "ci_low": ci_low,
        "ci_high": ci_high,
    }


def _append_calibration_bins(
    calib_rows: List[Dict[str, Any]],
    model_name: str,
    task: str,
    horizon: float,
    cause_id: int,
    cal: Dict[str, Any],
) -> None:
    """Append one calibration-CSV row per bin returned by calibration_deciles."""
    for binfo in cal.get("bins", []):
        calib_rows.append(
            {
                "model_name": model_name,
                "task": task,
                "horizon": float(horizon),
                "cause_id": int(cause_id),
                "bin_index": int(binfo["bin"]),
                "p_mean": float(binfo["p_mean"]),
                "y_mean": float(binfo["y_mean"]),
                "n_in_bin": int(binfo["n"]),
            }
        )


def _horizon_index(horizons: Sequence[float], tau: float, atol: float = 1e-6) -> Optional[int]:
    """Index of the first horizon within ``atol`` of ``tau``, or None."""
    for i, h in enumerate(horizons):
        if math.isclose(float(h), float(tau), rel_tol=0.0, abs_tol=atol):
            return i
    return None


def evaluate_one_model(
    model_name: str,
    allcause_risk: np.ndarray,
    cause_cif: np.ndarray,
    sex: np.ndarray,
    y_allcause: np.ndarray,
    y_cause_ever_anytime: np.ndarray,
    y_cause_within_tau: np.ndarray,
    y_cause_within_tau_max: np.ndarray,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
) -> None:
    """Compute metrics for one model and append them to ``out_rows`` / ``calib_rows``.

    Appended metrics:
      - all-cause, per horizon: Brier, AUC (+CI), ECE, ICI, calibration bins,
        AUC by sex (0/1; strata with < 10 samples are skipped);
      - per top cause at tau_max = max(eval_horizons): AUC with ever-after-context
        labels ("cause_auc_ever") and with within-tau_max labels ("cause_auc");
      - per top cause at tau in {3.84, 10.0} (when among eval_horizons): Brier,
        ECE, ICI and calibration bins against within-tau labels.

    ``auc_ci_method``: "none", "bootstrap", or anything else for DeLong.
    Horizon arrays follow the order of ``eval_horizons``, which need not be sorted.
    """
    horizons = [float(t) for t in eval_horizons]
    cause_ids = [int(c) for c in top_cause_ids.tolist()]

    # Task B (all-cause): Brier + AUC + calibration per horizon
    for h_i, tau in enumerate(horizons):
        p = allcause_risk[:, h_i]
        y = y_allcause[:, h_i]

        out_rows.append(_metric_row(model_name, "allcause_brier", tau, "", brier_score(p, y)))
        auc, lo, hi = _auc_with_ci(y, p, auc_ci_method, bootstrap_n)
        out_rows.append(_metric_row(model_name, "allcause_auc", tau, "", auc, lo, hi))

        cal = calibration_deciles(p, y, n_bins=n_calib_bins)
        out_rows.append(_metric_row(model_name, "allcause_ece", tau, "", cal["ece"]))
        out_rows.append(_metric_row(model_name, "allcause_ici", tau, "", cal["ici"]))
        # Calibration bins always go to the separate CSV for all-cause.
        _append_calibration_bins(calib_rows, model_name, "all_cause", tau, -1, cal)

        # Stratification by sex
        for s_val in (0, 1):
            m = sex == s_val
            if np.sum(m) < 10:
                continue
            auc_s, lo_s, hi_s = _auc_with_ci(y[m], p[m], auc_ci_method, bootstrap_n)
            out_rows.append(
                _metric_row(model_name, f"allcause_auc_sex{s_val}", tau, "", auc_s, lo_s, hi_s))

    # Use the column of the LARGEST horizon (not simply the last column) so that
    # scores line up with y_cause_within_tau_max even for unsorted eval_horizons.
    tau_max_idx = max(range(len(horizons)), key=horizons.__getitem__)
    tau_max = horizons[tau_max_idx]
    p_tau_max = cause_cif[:, :, tau_max_idx]  # (N, topK)

    # Task A (Delphi2M-compatible discrimination): per-cause AUC with EVER labels;
    # case/control is whether the disease appears ANYTIME after context.
    for j, cause_id in enumerate(cause_ids):
        auc, lo, hi = _auc_with_ci(
            y_cause_ever_anytime[:, j], p_tau_max[:, j], auc_ci_method, bootstrap_n)
        out_rows.append(
            _metric_row(model_name, "cause_auc_ever", tau_max, cause_id, auc, lo, hi))

    # Keep the existing tau-window AUC as a separate metric (do not remove).
    for j, cause_id in enumerate(cause_ids):
        auc, lo, hi = _auc_with_ci(
            y_cause_within_tau_max[:, j], p_tau_max[:, j], auc_ci_method, bootstrap_n)
        out_rows.append(
            _metric_row(model_name, "cause_auc", tau_max, cause_id, auc, lo, hi))

    # Task B additions: cause-specific Brier + calibration curves at tau=3.84 and 10.0
    for tau in (3.84, 10.0):
        h_idx = _horizon_index(horizons, tau)
        if h_idx is None:
            continue
        p_tau = cause_cif[:, :, h_idx]  # (N, topK)
        y_tau = y_cause_within_tau[:, :, h_idx]  # (N, topK)
        for j, cause_id in enumerate(cause_ids):
            p = p_tau[:, j]
            y = y_tau[:, j]
            out_rows.append(
                _metric_row(model_name, "cause_brier", tau, cause_id, brier_score(p, y)))
            cal = calibration_deciles(p, y, n_bins=n_calib_bins)
            out_rows.append(_metric_row(model_name, "cause_ece", tau, cause_id, cal["ece"]))
            out_rows.append(_metric_row(model_name, "cause_ici", tau, cause_id, cal["ici"]))
            # Cause calibration bins are written only for these tau targets.
            _append_calibration_bins(calib_rows, model_name, "cause_k", tau, cause_id, cal)


def _write_dict_rows_csv(path: str, fieldnames: Sequence[str], rows: List[Dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header in ``fieldnames`` order."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(fieldnames))
        w.writeheader()
        w.writerows(rows)


def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write calibration-bin rows (one row per model/task/horizon/cause/bin)."""
    _write_dict_rows_csv(
        path,
        [
            "model_name",
            "task",
            "horizon",
            "cause_id",
            "bin_index",
            "p_mean",
            "y_mean",
            "n_in_bin",
        ],
        rows,
    )


def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write metric rows (one row per model/metric/horizon/cause)."""
    _write_dict_rows_csv(
        path,
        [
            "model_name",
            "metric_name",
            "horizon",
            "cause",
            "value",
            "ci_low",
            "ci_high",
        ],
        rows,
    )


def _make_eval_tag(split: str, offset_years: float) -> str:
    """Short tag for filenames written into run directories."""
    off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
    return f"{split}_offset{off}y"


def main() -> int:
    ap = argparse.ArgumentParser(
        description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True,
                    help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str,
                    default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"], help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5,
                    help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float,
                    nargs="*", default=DEFAULT_EVAL_HORIZONS)
    ap.add_argument("--top_k_causes", type=int, default=50)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu")
ap.add_argument("--out_csv", type=str, default="eval_results.csv") ap.add_argument("--out_meta_json", type=str, default="eval_meta.json") # Integrity checks ap.add_argument("--integrity_strict", action="store_true", default=False) ap.add_argument("--integrity_tol", type=float, default=1e-6) # AUC CI methods ap.add_argument( "--auc_ci_method", type=str, default="delong", choices=["delong", "bootstrap", "none"], ) ap.add_argument("--bootstrap_n", type=int, default=2000) args = ap.parse_args() set_deterministic(args.seed) specs = load_models_json(args.models_json) if not specs: raise ValueError("No models provided") # Determine top-K causes from the evaluation split only (model-agnostic). first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path) cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] dataset_for_top = HealthDataset( data_prefix=args.data_prefix, covariate_list=cov_list) subset_for_top = build_eval_subset( dataset_for_top, train_ratio=float(first_cfg.get("train_ratio", 0.7)), val_ratio=float(first_cfg.get("val_ratio", 0.15)), seed=int(first_cfg.get("random_seed", 42)), split=args.split, ) loader_top = DataLoader( subset_for_top, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=health_collate_fn, ) counts, n_total_eval = count_ever_after_context_anytime( loader=loader_top, offset_years=args.offset_years, n_disease=dataset_for_top.n_disease, device=args.device, ) order = np.argsort(-counts) order = order[counts[order] > 0] top_cause_ids = order[: args.top_k_causes] # Record top-cause counts under Delphi2M-compatible EVER label. 
top_causes_meta: List[Dict[str, Any]] = [] for k in top_cause_ids.tolist(): n_case = int(counts[int(k)]) top_causes_meta.append( { "cause_id": int(k), "n_case_ever": n_case, "n_control_ever": int(n_total_eval - n_case), "n_total_eval": int(n_total_eval), } ) rows: List[Dict[str, Any]] = [] calib_rows: List[Dict[str, Any]] = [] # Track per-model integrity status for meta JSON. integrity_meta: Dict[str, Any] = {} # Evaluate each model for spec in specs: run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path)) tag = _make_eval_tag(args.split, float(args.offset_years)) # Remember list offsets so we can write per-model slices to the model's run_dir. rows_start = len(rows) calib_start = len(calib_rows) cfg = load_train_config_for_checkpoint(spec.checkpoint_path) cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] dataset = HealthDataset( data_prefix=args.data_prefix, covariate_list=cov_list) subset = build_eval_subset( dataset, train_ratio=float(cfg.get("train_ratio", 0.7)), val_ratio=float(cfg.get("val_ratio", 0.15)), seed=int(cfg.get("random_seed", 42)), split=args.split, ) loader = DataLoader( subset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=health_collate_fn, ) backbone, head, loss_type, bin_edges = instantiate_model_and_head( cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path) ckpt = torch.load(spec.checkpoint_path, map_location=args.device) backbone.load_state_dict(ckpt["model_state_dict"], strict=True) head.load_state_dict(ckpt["head_state_dict"], strict=True) ( allcause_risk, cause_cif, cif_full, survival, sex, y_allcause, y_cause_ever_anytime, y_cause_within_tau, y_cause_within_tau_max, ) = predict_cifs_for_model( backbone, head, loss_type, bin_edges, loader, args.device, args.offset_years, args.eval_horizons, top_cause_ids, ) # CIF integrity checks before metrics. 
integrity_ok, integrity_notes = check_cif_integrity( cif_full, args.eval_horizons, tol=float(args.integrity_tol), name=spec.name, strict=bool(args.integrity_strict), survival=survival, ) integrity_meta[spec.name] = { "integrity_ok": bool(integrity_ok), "integrity_notes": integrity_notes, } evaluate_one_model( model_name=spec.name, allcause_risk=allcause_risk, cause_cif=cause_cif, sex=sex, y_allcause=y_allcause, y_cause_ever_anytime=y_cause_ever_anytime, y_cause_within_tau=y_cause_within_tau, y_cause_within_tau_max=y_cause_within_tau_max, eval_horizons=args.eval_horizons, top_cause_ids=top_cause_ids, out_rows=rows, calib_rows=calib_rows, auc_ci_method=str(args.auc_ci_method), bootstrap_n=int(args.bootstrap_n), ) # Optionally write top-cause counts into the main results CSV as metric rows. for tc in top_causes_meta: rows.append( { "model_name": spec.name, "metric_name": "topcause_n_case_ever", "horizon": "", "cause": int(tc["cause_id"]), "value": int(tc["n_case_ever"]), "ci_low": "", "ci_high": "", } ) rows.append( { "model_name": spec.name, "metric_name": "topcause_n_control_ever", "horizon": "", "cause": int(tc["cause_id"]), "value": int(tc["n_control_ever"]), "ci_low": "", "ci_high": "", } ) rows.append( { "model_name": spec.name, "metric_name": "topcause_n_total_eval", "horizon": "", "cause": int(tc["cause_id"]), "value": int(tc["n_total_eval"]), "ci_low": "", "ci_high": "", } ) # Write per-model results into the model's run directory. 
model_rows = rows[rows_start:] model_calib_rows = calib_rows[calib_start:] model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv") model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv") model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json") write_results_csv(model_out_csv, model_rows) write_calibration_bins_csv(model_calib_csv, model_calib_rows) model_meta = { "model_name": spec.name, "checkpoint_path": spec.checkpoint_path, "run_dir": run_dir, "split": args.split, "offset_years": args.offset_years, "eval_horizons": [float(x) for x in args.eval_horizons], "top_k_causes": int(args.top_k_causes), "top_cause_ids": top_cause_ids.tolist(), "top_causes": top_causes_meta, "integrity": {spec.name: integrity_meta.get(spec.name, {})}, "paths": { "results_csv": model_out_csv, "calibration_bins_csv": model_calib_csv, }, } with open(model_meta_json, "w") as f: json.dump(model_meta, f, indent=2) print(f"Wrote per-model results to {model_out_csv}") write_results_csv(args.out_csv, rows) # Write calibration curve points to a separate CSV. out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "." 
calib_csv_path = os.path.join(out_dir, "calibration_bins.csv") write_calibration_bins_csv(calib_csv_path, calib_rows) meta = { "split": args.split, "offset_years": args.offset_years, "eval_horizons": [float(x) for x in args.eval_horizons], "top_k_causes": int(args.top_k_causes), "top_cause_ids": top_cause_ids.tolist(), "top_causes": top_causes_meta, "integrity": integrity_meta, "notes": { "task_a_label": "Delphi2M-compatible: disease occurs ANYTIME after context (ever in remaining sequence)", "task_a_legacy_label": "Secondary: disease occurs within tau_max after context", "task_b_label": "all-cause event within horizon (equivalent to next disease event within horizon)", "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.", }, } with open(args.out_meta_json, "w") as f: json.dump(meta, f, indent=2) print(f"Wrote {args.out_csv} with {len(rows)} rows") print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows") print(f"Wrote {args.out_meta_json}") return 0 if __name__ == "__main__": raise SystemExit(main())