import argparse
import csv
import json
import math
import os
import random
import statistics
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi, SimpleHead

# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
# Bin edges and evaluation horizons are in years; token times are in days and
# are converted with DAYS_PER_YEAR throughout this module.
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
DAYS_PER_YEAR = 365.25
# 0-based disease id (range-checked against n_disease in pick_focus_causes).
DEFAULT_DEATH_CAUSE_ID = 1256


# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
    """One model entry parsed from the models JSON file by load_models_json."""

    name: str
    model_type: str  # delphi_fork | sap_delphi
    loss_type: str  # exponential | discrete_time_cif
    full_cov: bool
    checkpoint_path: str


# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
    """Seed Python, NumPy and torch (CPU + all CUDA devices); force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
    """Parse a bool-like value.

    Accepts real bools, or (case-insensitive, whitespace-stripped) strings
    "true"/"1"/"yes"/"y" and "false"/"0"/"no"/"n".

    Raises:
        ValueError: if the value is not recognised.
    """
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    if s in {"true", "1", "yes", "y"}:
        return True
    if s in {"false", "0", "no", "n"}:
        return False
    raise ValueError(f"Cannot parse boolean: {x!r}")


def load_models_json(path: str) -> List[ModelSpec]:
    """Load a JSON list of model entries into ModelSpec objects.

    Each entry must provide name, model_type, loss_type, full_cov and
    checkpoint_path (a missing key raises KeyError).

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    # JSON text is UTF-8 by spec; do not depend on the platform default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")
    specs: List[ModelSpec] = []
    for row in data:
        specs.append(
            ModelSpec(
                name=str(row["name"]),
                model_type=str(row["model_type"]),
                loss_type=str(row["loss_type"]),
                full_cov=_parse_bool(row["full_cov"]),
                checkpoint_path=str(row["checkpoint_path"]),
            )
        )
    return specs


def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    """Read ``train_config.json`` from the directory containing the checkpoint."""
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    return cfg


def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    """Return the requested split of ``dataset``.

    The train/val/test partition is produced by ``random_split`` with a
    generator seeded by ``seed``; the test split receives the remainder so the
    three sizes always sum to ``len(dataset)``. ``split="all"`` returns the
    full dataset (the partition is still computed first).

    Raises:
        ValueError: if ``split`` is not one of train, val, test, all.
    """
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    if split == "train":
        return train_ds
    if split == "val":
        return val_ds
    if split == "test":
        return test_ds
    if split == "all":
        return dataset
    raise ValueError("split must be one of: train, val, test, all")


# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select per-sample prediction context index.

    IMPORTANT SEMANTICS:
    - The last observed token time is treated as the FOLLOW-UP END time.
    - We pick the last valid token with time <= (followup_end_time - offset).
    - We do NOT interpret followup_end_time as an event time.

    Returns:
        keep_mask: (B,) bool, which samples have a valid context
        t_ctx: (B,) long, index into sequence
        t_ctx_time: (B,) float, time (days) at context
    """
    # valid tokens are event != 0 (padding is 0)
    valid = event_seq != 0
    lengths = valid.sum(dim=1)
    # NOTE: indexing by count assumes valid tokens form a contiguous,
    # time-sorted prefix (padding at the end).
    last_idx = torch.clamp(lengths - 1, min=0)
    b = torch.arange(event_seq.size(0), device=event_seq.device)
    followup_end_time = time_seq[b, last_idx]
    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    eligible_counts = eligible.sum(dim=1)
    keep = eligible_counts > 0
    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[b, t_ctx]
    return keep, t_ctx, t_ctx_time


def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return next disease event after context.

    Returns:
        dt_years: (B,) float, time to next disease in years; +inf if none
        cause: (B,) long, disease id in [0,K) for next event; -1 if none
    """
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    # Allow same-day events while excluding the context token itself.
    # We rely on time-sorted sequences and select the FIRST valid future event by index.
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
    # Positions that are not future events get sentinel L so min() ignores them.
    idx_min = torch.where(future, idxs, torch.full_like(idxs, L)).min(dim=1).values
    has = idx_min < L
    t_next = torch.where(has, idx_min, torch.zeros_like(idx_min))
    t_next_time = time_seq[b, t_next]
    dt_days = t_next_time - t0
    dt_years = dt_days / DAYS_PER_YEAR
    dt_years = torch.where(has, dt_years, torch.full_like(dt_years, float("inf")))
    cause_token = event_seq[b, t_next]
    # Disease ids are token ids shifted by the 2 non-disease tokens.
    cause = (cause_token - 2).to(torch.long)
    cause = torch.where(has, cause, torch.full_like(cause, -1))
    return dt_years, cause


def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after context (any occurrence)."""
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    # Include same-day events after context, exclude any token at/before context index.
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )
    if not in_window.any():
        return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    y[b_idx, disease_ids] = True
    return y


def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs ANYTIME after the prediction context.

    This is Delphi2M-compatible for Task A case/control definition.
    Same-day events are included as long as they occur after the context token index.
    """
    B, L = event_seq.shape
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    if not future.any():
        return y
    b_idx, t_idx = future.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y[b_idx, disease_ids] = True
    return y


def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after context?

    Returns a (B, len(cause_ids)) bool tensor whose column j refers to cause_ids[j].
    """
    B, L = event_seq.shape
    device = event_seq.device
    b = torch.arange(B, device=device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)
    idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )
    out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
    if not in_window.any():
        return out
    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    # Filter to selected causes via a boolean membership mask over the global disease space.
    selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
    selected[cause_ids] = True
    keep = selected[disease_ids]
    if not keep.any():
        return out
    b_idx = b_idx[keep]
    disease_ids = disease_ids[keep]
    # Map disease_id -> local index in cause_ids
    # Build a lookup table (global disease space) where lookup[disease_id] = local_index
    lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
    lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
    local = lookup[disease_ids]
    out[b_idx, local] = True
    return out


# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential cause-specific logits -> CIFs at taus.

    Hazards are softplus(logits) + eps (constant over time), so with total
    hazard Lambda: CIF_k(tau) = (h_k / Lambda) * (1 - exp(-Lambda * tau)) and
    survival(tau) = exp(-Lambda * tau). taus must use the same time unit as the
    hazards (the callers in this file pass horizons in years).

    logits: (B, K)
    returns: (B, K, H) or (cif, survival) if return_survival
    """
    hazards = F.softplus(logits) + eps
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)
    taus_t = torch.tensor(list(taus), device=logits.device, dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B,1,1)
    # (1 - exp(-Lambda * tau))
    one_minus_surv = 1.0 - torch.exp(-total_h * taus_t)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B,K,H)
    # If total==0, set to 0
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))
    if not return_survival:
        return cif
    survival = torch.exp(-total_h * taus_t).squeeze(1)  # (B,H)
    # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival


def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert discrete-time CIF logits -> CIFs at taus.

    logits: (B, K+1, n_bins+1)
    bin_edges: len=n_bins+1 (including 0 and inf)
    taus: subset of finite bin edges (recommended)
    returns: (B, K, H) or (cif, survival) if return_survival

    Raises:
        ValueError: on wrong logits rank, a last dim that does not match
            len(bin_edges), a non-finite tau, or a tau not within 1e-6 of a
            finite bin edge.
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")
    B, K_plus_1, n_bins_plus_1 = logits.shape
    K = K_plus_1 - 1
    edges = [float(x) for x in bin_edges]
    # drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)
    if n_bins_plus_1 != len(edges):
        raise ValueError("logits last dim must match len(bin_edges)")
    # Softmax over the K causes + 1 complement ("no event in this bin") class.
    probs = torch.softmax(logits, dim=1)  # (B, K+1, n_bins+1)
    # use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot)
    hazards = probs[:, :K, 1: 1 + n_bins]  # (B,K,n_bins)
    p_comp = probs[:, K, 1: 1 + n_bins]  # (B,n_bins)
    # survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v< u} p_comp[v]
    ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
    cum = torch.cumprod(p_comp, dim=1)
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # (B, n_bins)
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * hazards, dim=2)  # (B,K,n_bins)

    # Robust mapping from tau -> edge index (floating-point safe).
    # taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        j = int(np.argmin(diffs))
        if diffs[j] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
            )
        tau_to_idx.append(j)
    idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,K,H)
    if not return_survival:
        return cif
    # Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
    survival_bins = cum  # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
    survival = survival_bins.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival


# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on CIF arrays.

    Args:
        cause_cif: (N, K, H)
        horizons: length H
        tol: tolerance for inequalities
        name: model name for messages
        strict: if True, raise ValueError on first failure
        survival: optional (N, H) survival values at the same horizons

    Returns:
        (integrity_ok, notes). A missing ``survival`` is noted but does not
        make integrity_ok False. An empty ``cause_cif`` is reported as a
        failure because no check can be performed on it.

    Raises:
        ValueError: on the first failed check when ``strict`` is True.
    """
    notes: List[str] = []
    model_tag = f"[{name}] " if name else ""

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    N, K, H = cif.shape
    if H != len(horizons):
        _fail(
            f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")
    # Guard: the nan-reductions below raise on zero-size arrays (e.g. K=0 when
    # no focus causes were selected); report it explicitly instead.
    if cif.size == 0:
        _fail(f"integrity: cause_cif is empty (shape={cif.shape}); no checks performed")
        return False, notes

    # (5) Finite
    if not np.isfinite(cif).all():
        _fail("integrity: non-finite values (NaN/Inf) in cause_cif")

    # (1) Range
    cmin = float(np.nanmin(cif))
    cmax = float(np.nanmax(cif))
    if cmin < -tol:
        _fail(f"integrity: range min={cmin} < -tol={-tol}")
    if cmax > 1.0 + tol:
        _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")

    # (2) Monotonicity in horizons (per n,k)
    diffs = np.diff(cif, axis=2)
    if diffs.size > 0:
        if np.nanmin(diffs) < -tol:
            _fail("integrity: monotonicity violated (found negative diff along horizons)")

    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N,H)
    mass_max = float(np.nanmax(mass))
    if mass_max > 1.0 + tol:
        _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")

    # (4) Conservation with survival, if provided
    if survival is None:
        warn = "integrity: survival not provided; skipping conservation check"
        if strict:
            # still skip (requested behavior), but keep message for context
            notes.append(warn)
        else:
            print("WARNING:", model_tag + warn)
            notes.append(warn)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(
                f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        else:
            recon = 1.0 - s
            err = np.abs(recon - mass)
            # Discrete-time should be very tight; exponential may accumulate slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if float(np.nanmax(err)) > tol_cons:
                _fail(
                    f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})")

    # The "survival not provided" note is informational only.
    ok = len([n for n in notes if not n.endswith("skipping conservation check")]) == 0
    return ok, notes


# ============================================================
# Metrics
# ============================================================
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Return 1-based midranks of ``x`` (tied values share their average rank)."""
    x = np.asarray(x, dtype=float)
    order = np.argsort(x)
    z = x[order]
    n = x.shape[0]
    t = np.zeros(n, dtype=float)
    i = 0
    while i < n:
        j = i
        while j < n and z[j] == z[i]:
            j += 1
        t[i:j] = 0.5 * (i + j - 1) + 1.0
        i = j
    out = np.empty(n, dtype=float)
    out[order] = t
    return out


def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for computing AUC covariance.

    predictions_sorted_transposed: shape (n_classifiers, n_examples) with positive examples first.

    Returns (aucs, covariance); NaN placeholders when either class is empty.
    """
    preds = np.asarray(predictions_sorted_transposed, dtype=float)
    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        return np.array([float("nan")]), np.array([[float("nan")]])
    pos = preds[:, :m]
    neg = preds[:, m:]
    tx = np.array([compute_midrank(x) for x in pos])
    ty = np.array([compute_midrank(x) for x in neg])
    tz = np.array([compute_midrank(x) for x in preds])
    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
    v01 = (tz[:, :m] - tx) / n
    v10 = 1.0 - (tz[:, m:] - ty) / m
    if v01.shape[0] > 1:
        sx = np.cov(v01)
        sy = np.cov(v10)
    else:
        # Single-classifier case: compute row-wise variance (do not flatten).
        var_v01 = float(np.var(v01, axis=1, ddof=1)[0])
        var_v10 = float(np.var(v10, axis=1, ddof=1)[0])
        sx = np.array([[var_v01]])
        sy = np.array([[var_v10]])
    delong_cov = sx / m + sy / n
    return aucs, delong_cov


def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    """Return (AUC, DeLong variance) for 1D labels (1=positive, 0=negative).

    Returns (nan, nan) if either class is empty.

    Raises:
        ValueError: if inputs are not 1D arrays of equal length.
    """
    y = np.asarray(ground_truth, dtype=int)
    p = np.asarray(predictions, dtype=float)
    if y.ndim != 1 or p.ndim != 1 or y.shape[0] != p.shape[0]:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan"), float("nan")
    order = np.argsort(-y)  # positives first
    preds_sorted = p[order]
    aucs, cov = fastDeLong(preds_sorted[np.newaxis, :], m)
    auc = float(aucs[0])
    var = float(cov[0, 0])
    return auc, var


def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return (auc, ci_low, ci_high) using DeLong variance and normal CI.

    ``alpha`` is the confidence level (0.95 -> 95% CI); bounds are clipped to [0, 1].
    """
    auc, var = calc_auc_variance(ground_truth, predictions)
    if not np.isfinite(var) or var <= 0:
        print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
        return float(auc), float("nan"), float("nan")
    sd = math.sqrt(var)
    z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0)
    lo = max(0.0, auc - z * sd)
    hi = min(1.0, auc + z * sd)
    return float(auc), float(lo), float(hi)


def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC via Mann–Whitney U statistic (ties handled by midranks).

    Returns NaN for degenerate labels.
    """
    y = np.asarray(y_true, dtype=int)
    s = np.asarray(y_score, dtype=float)
    if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan")
    ranks = compute_midrank(s)
    sum_pos = float(np.sum(ranks[y == 1]))
    auc = (sum_pos - m * (m + 1) / 2.0) / (m * n)
    return float(auc)


def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Bootstrap CI for ROC AUC (percentile).

    Resamples with single-class labels are skipped; fewer than 10 valid
    resamples yields a NaN CI (the full-sample AUC is still returned).
    """
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    if n == 0 or np.all(labels == labels[0]):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")
    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")
    aucs: List[float] = []
    for _ in range(int(n_bootstrap)):
        idx = rng.integers(0, n, size=n)
        yb = labels[idx]
        if np.all(yb == yb[0]):
            continue
        pb = scores[idx]
        auc = roc_auc_rank(yb, pb)
        if np.isfinite(auc):
            aucs.append(float(auc))
    if len(aucs) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), float("nan"), float("nan")
    lo_q = (1.0 - float(alpha)) / 2.0
    hi_q = 1.0 - lo_q
    lo = float(np.quantile(aucs, lo_q))
    hi = float(np.quantile(aucs, hi_q))
    return float(auc_full), lo, hi


def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    """Mean squared difference between predicted probabilities and outcomes."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.mean((p - y) ** 2))


def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    """Quantile-binned calibration summary.

    Bins are (lo, hi] intervals between quantiles of ``p`` (outer edges at
    +/-inf); empty bins, e.g. from tied quantiles, are skipped. ``ici`` is the
    unweighted mean of |p_mean - y_mean| over the non-empty bins.
    """
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    # guard
    if p.size == 0:
        return {"bins": [], "ici": float("nan")}
    edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
    # make strictly increasing where possible
    edges[0] = -np.inf
    edges[-1] = np.inf
    bins = []
    ici_accum = 0.0
    for i in range(n_bins):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        if not np.any(mask):
            continue
        p_mean = float(np.mean(p[mask]))
        y_mean = float(np.mean(y[mask]))
        bins.append({"bin": i, "p_mean": p_mean, "y_mean": y_mean, "n": int(mask.sum())})
        ici_accum += abs(p_mean - y_mean)
    ici = ici_accum / max(len(bins), 1)
    return {"bins": bins, "ici": float(ici)}


def _safe_float(x: Any, default: float = float("nan")) -> float:
    """Return float(x), or ``default`` if the conversion raises."""
    try:
        return float(x)
    except Exception:
        return float(default)


def _ensure_dir(path: str) -> None:
    """Create ``path`` (and parents) if it does not exist."""
    os.makedirs(path, exist_ok=True)


def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
    """Load 0-based cause_id -> name mapping.

    labels.csv is assumed to be one label per line, in disease-id order.
    Blank lines keep their index slot but produce no entry; a missing file
    yields an empty mapping.
    """
    if not os.path.exists(path):
        return {}
    mapping: Dict[int, str] = {}
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            name = line.strip()
            if name:
                mapping[int(i)] = name
    return mapping


def pick_focus_causes(
    *,
    counts_within_tau: Optional[np.ndarray],
    n_disease: int,
    death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
    k: int = 5,
) -> List[int]:
    """Pick focus causes for user-facing evaluation.

    Rule:
    1) Always include death_cause_id first.
    2) Then add K additional causes by descending event count if available.
       If counts_within_tau is None, fall back to descending cause_id coverage proxy.

    Notes:
    - counts_within_tau is expected to be shape (n_disease,).
    - Deterministic: ties broken by smaller cause id.
    - Causes with zero count are not added, so fewer than K extras may be returned.
    """
    n_disease_i = int(n_disease)
    if death_cause_id < 0 or death_cause_id >= n_disease_i:
        print(
            f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); "
            "it will be omitted from focus causes."
        )
        focus: List[int] = []
    else:
        focus = [int(death_cause_id)]

    candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)]

    if counts_within_tau is not None:
        c = np.asarray(counts_within_tau).astype(float)
        if c.shape[0] != n_disease_i:
            print(
                "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering."
            )
            counts_within_tau = None
        else:
            # Sort by (-count, cause_id)
            order = sorted(candidates, key=lambda i: (-float(c[i]), int(i)))
            order = [i for i in order if float(c[i]) > 0]
            focus.extend([int(i) for i in order[: int(k)]])

    if counts_within_tau is None:
        # Fallback: deterministic coverage proxy (descending id, excluding death), then take K.
        # (Real coverage requires data; this path is mostly for robustness.)
        order = sorted(candidates, key=lambda i: (-int(i)))
        focus.extend([int(i) for i in order[: int(k)]])

    # De-dup while preserving order
    seen = set()
    out: List[int] = []
    for cid in focus:
        if cid not in seen:
            out.append(cid)
            seen.add(cid)
    return out


def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
    """Write ``rows`` as a UTF-8 CSV with a header, creating the parent directory."""
    _ensure_dir(os.path.dirname(os.path.abspath(path)) or ".")
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
    """Return list of (sex_label, mask) slices including an 'all' slice.

    If sex is missing, returns only ('all', None). Per-value slices are added
    for sex values 0 and 1 when present.
    """
    out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
    if sex is None:
        return out
    s = np.asarray(sex)
    if s.ndim != 1:
        return out
    for val in [0, 1]:
        m = (s == val)
        if int(np.sum(m)) > 0:
            out.append((str(val), m))
    return out


def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
    """Return q+1 quantile edges of ``p`` with the outer edges set to -inf/+inf."""
    edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1))
    edges = np.asarray(edges, dtype=float)
    edges[0] = -np.inf
    edges[-1] = np.inf
    return edges


def compute_risk_stratification_bins(
    p: np.ndarray,
    y: np.ndarray,
    *,
    q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
    """Compute quantile-based risk strata and a compact summary.

    Strata use ``q_default`` quantiles, or 5 when fewer than 200 samples.

    Returns:
        (q, bin_rows, summary) where bin_rows has one row per stratum (empty
        strata are kept with NaNs) and summary holds the overall event rate,
        the event rate in the top 10% and bottom 50% of predicted risk, their
        ratio, and the count-weighted slope of observed rate vs. mean
        prediction across strata.
    """
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return 0, [], {
            "y_overall": float("nan"),
            "top_decile_y_rate": float("nan"),
            "bottom_half_y_rate": float("nan"),
            "lift_top10_vs_bottom50": float("nan"),
            "slope_pred_vs_obs": float("nan"),
        }

    # Choose quantiles robustly.
    q = int(q_default)
    if n < 200:
        q = 5
    edges = _quantile_edges(p, q)
    y_overall = float(np.mean(y))

    bin_rows: List[Dict[str, Any]] = []
    p_means: List[float] = []
    y_rates: List[float] = []
    n_bins: List[int] = []
    for i in range(q):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        nb = int(np.sum(mask))
        if nb == 0:
            # Keep the row for consistent plotting; set NaNs.
            bin_rows.append(
                {
                    "q": int(i + 1),
                    "n_bin": 0,
                    "p_mean": float("nan"),
                    "y_rate": float("nan"),
                    "y_overall": y_overall,
                    "lift_vs_overall": float("nan"),
                }
            )
            continue
        p_mean = float(np.mean(p[mask]))
        y_rate = float(np.mean(y[mask]))
        lift = float(y_rate / y_overall) if y_overall > 0 else float("nan")
        bin_rows.append(
            {
                "q": int(i + 1),
                "n_bin": nb,
                "p_mean": p_mean,
                "y_rate": y_rate,
                "y_overall": y_overall,
                "lift_vs_overall": lift,
            }
        )
        p_means.append(p_mean)
        y_rates.append(y_rate)
        n_bins.append(nb)

    # Summary.
    # Use explicit 90th / 50th percentiles so "top decile" and "bottom half"
    # match their names for any q. (Previously edges[q-1] / edges[q//2] were
    # used, which for q=5 are the 80th and 40th percentiles; for q=10 the
    # result is unchanged.)
    p_top = float(np.quantile(p, 0.9))
    p_median = float(np.quantile(p, 0.5))
    top_mask = (p > p_top) & (p <= edges[q])
    bot_half_mask = (p > edges[0]) & (p <= p_median)
    top_y = float(np.mean(y[top_mask])) if int(np.sum(top_mask)) > 0 else float("nan")
    bot_y = float(np.mean(y[bot_half_mask])) if int(np.sum(bot_half_mask)) > 0 else float("nan")
    lift_top_vs_bottom = (
        float(top_y / bot_y)
        if (np.isfinite(top_y) and np.isfinite(bot_y) and bot_y > 0)
        else float("nan")
    )

    slope = float("nan")
    if len(p_means) >= 2:
        # Weighted least squares slope of y_rate ~ p_mean.
        x = np.asarray(p_means, dtype=float)
        yy = np.asarray(y_rates, dtype=float)
        w = np.asarray(n_bins, dtype=float)
        xm = float(np.average(x, weights=w))
        ym = float(np.average(yy, weights=w))
        denom = float(np.sum(w * (x - xm) ** 2))
        if denom > 0:
            slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom)

    summary = {
        "y_overall": y_overall,
        "top_decile_y_rate": top_y,
        "bottom_half_y_rate": bot_y,
        "lift_top10_vs_bottom50": lift_top_vs_bottom,
        "slope_pred_vs_obs": slope,
    }
    return q, bin_rows, summary


def compute_capture_points(
    p: np.ndarray,
    y: np.ndarray,
    k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
    """Event capture when targeting the top k% of predicted risk.

    For each k, the top ceil(n*k/100) scores (clamped to [1, n]) are targeted;
    reports the share of all events captured and the event rate among targeted.
    """
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return []
    order = np.argsort(-p)
    y_sorted = y[order]
    events_total = float(np.sum(y_sorted))
    rows: List[Dict[str, Any]] = []
    for k in k_pcts:
        kf = float(k)
        n_targeted = int(math.ceil(n * kf / 100.0))
        n_targeted = max(1, min(n_targeted, n))
        events_targeted = float(np.sum(y_sorted[:n_targeted]))
        capture = float(events_targeted / events_total) if events_total > 0 else float("nan")
        precision = float(events_targeted / float(n_targeted))
        rows.append(
            {
                "k_pct": int(k),
                "n_targeted": int(n_targeted),
                "events_targeted": float(events_targeted),
                "events_total": float(events_total),
                "event_capture_rate": capture,
                "precision_in_targeted": precision,
            }
        )
    return rows


def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]:
    """Bucketize horizons into short/medium/long using the continuous-horizon rule."""
    uniq = sorted({float(h) for h in horizons})
    mapping: Dict[float, str] = {}
    rows: List[Dict[str, Any]] = []
    # First 4 short, next 4 medium, rest long.
    for i, h in enumerate(uniq):
        if i < 4:
            g, gr = "short", 1
        elif i < 8:
            g, gr = "medium", 2
        else:
            g, gr = "long", 3
        mapping[float(h)] = g
        rows.append({"horizon": float(h), "group": g, "group_rank": int(gr)})
    method = "continuous_unique_horizons_first4_next4_rest"
    return rows, mapping, method


def count_occurs_within_horizon(
    loader: DataLoader,
    offset_years: float,
    tau_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count per-person occurrence within tau after the prediction context.

    Returns counts[k] = number of individuals with disease k at least once in
    (t_ctx, t_ctx+tau], plus the number of individuals with a valid context.
    """
    counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_total_eval = 0
    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        n_total_eval += int(keep.sum().item())
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        t_ctx = t_ctx[keep]
        B, L = event_seq.shape
        b = torch.arange(B, device=device)
        t0 = time_seq[b, t_ctx]
        t1 = t0 + (float(tau_years) * DAYS_PER_YEAR)
        idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
        in_window = (
            (idxs > t_ctx.unsqueeze(1))
            & (time_seq >= t0.unsqueeze(1))
            & (time_seq <= t1.unsqueeze(1))
            & (event_seq >= 2)
            & (event_seq != 0)
        )
        if not in_window.any():
            continue
        b_idx, t_idx = in_window.nonzero(as_tuple=True)
        disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
        # unique per (person, disease) to count per-person within-window occurrence
        key = b_idx.to(torch.long) * int(n_disease) + disease_ids
        uniq = torch.unique(key)
        uniq_disease = uniq % int(n_disease)
        counts.scatter_add_(0, uniq_disease, torch.ones_like(uniq_disease, dtype=torch.long))
    return counts.detach().cpu().numpy(), int(n_total_eval)


# ============================================================
# Evaluation core
# ============================================================


def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    """Build the backbone and output head described by a training config.

    No checkpoint weights are loaded here; ``checkpoint_path`` is only used to
    enrich the warning printed when a SapDelphi embedding path is missing.

    Args:
        cfg: training config. Reads model_type, loss_type, n_embd, n_head,
            n_layer, and optionally pdrop, age_encoder, bin_edges,
            pretrained_emb_path / pretrained_emd_path.
        dataset: provides n_disease, n_cont, n_cate, cate_dims.
        device: device the modules are moved to.
        checkpoint_path: used only in the missing-embedding warning.

    Returns:
        (backbone, head, loss_type, bin_edges); bin_edges defaults to
        DEFAULT_BIN_EDGES when absent from cfg.

    Raises:
        ValueError: for an unsupported loss_type or model_type.
    """
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])

    # Head output size depends on the loss: one logit per disease for the
    # exponential model; (n_disease + 1) x len(bin_edges) for discrete-time CIF,
    # matching the (B, K+1, n_bins+1) layout cifs_from_discrete_time_logits expects.
    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")

    if model_type == "delphi_fork":
        backbone = DelphiFork(
            n_disease=dataset.n_disease,
            # NOTE(review): 2 presumably matches the two non-disease token ids
            # (disease id = token - 2 elsewhere in this file) — confirm in model.py.
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
        ).to(device)
    elif model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in {"", None}:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in {"", None}:
            # Only warn; the model is still built with emb_path=None.
            run_dir = os.path.dirname(os.path.abspath(
                checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    # Re-read so bin_edges is defined for the exponential branch too.
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
    return backbone, head, loss_type, bin_edges


@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run model and produce cause-specific, time-dependent CIF outputs.

    Each batch is filtered to samples with a valid context (see
    select_context_indices); the backbone output at the context index is fed
    to the head and converted to CIFs at ``eval_horizons``.

    Returns:
        cause_cif: (N, topK, H)
        cif_full: (N, K, H)
        survival: (N, H)
        y_cause_within_tau: (N, topK, H)
        sex: (N,) sex values of the kept samples, in the same order

    Raises:
        ValueError: for an unsupported loss_type.
        RuntimeError: if no sample in the loader has a valid context.

    NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
    """
    backbone.eval()
    head.eval()

    # We will accumulate in CPU lists, then concat.
    cause_cif_list: List[np.ndarray] = []
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []

    top_cause_ids_t = torch.tensor(
        top_cause_ids, dtype=torch.long, device=device)

    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue

        # filter batch
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep]
        t_ctx = t_ctx[keep]

        h = backbone(event_seq, time_seq, sexes_k,
                     cont_feats, cate_feats)  # (B,L,D)
        # Take the hidden state at each sample's context token.
        b = torch.arange(h.size(0), device=device)
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)

        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(  # (B,K,H), (B,H)
                logits, bin_edges, eval_horizons, return_survival=True)
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")

        cause_cif = cif_full.index_select(
            dim=1, index=top_cause_ids_t)  # (B,topK,H)

        # Within-horizon labels for cause-specific CIF quality + discrimination.
        n_disease = int(cif_full.size(1))
        y_within_top = torch.stack(
            [
                multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau),
                    cause_ids=top_cause_ids_t,
                    n_disease=n_disease,
                ).to(torch.float32)
                for tau in eval_horizons
            ],
            dim=2,
        )  # (B,topK,H)

        cause_cif_list.append(cause_cif.detach().cpu().numpy())
        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())

    if not cause_cif_list:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")

    cause_cif = np.concatenate(cause_cif_list, axis=0)
    cif_full = np.concatenate(cif_full_list, axis=0)
    survival = np.concatenate(survival_list, axis=0)
    y_cause_within = np.concatenate(y_cause_within_list, axis=0)
    # sex_list is non-empty here whenever cause_cif_list is; the fallback is defensive.
    sex = np.concatenate(
        sex_list, axis=0) if sex_list else np.array([], dtype=int)
    return cause_cif, cif_full, survival, y_cause_within, sex


def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    """Return up to ``top_k`` column indices of ``y_ever`` by descending column sum.

    Columns with a zero sum are dropped. Tie order follows np.argsort's
    default (non-stable) sort.
    """
    counts = y_ever.sum(axis=0)
    order = np.argsort(-counts)
    order = order[counts[order] > 0]
    return order[:top_k]


def evaluate_one_model(
    model_name: str,
    cause_cif: np.ndarray,
    y_cause_within_tau: np.ndarray,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
) -> None:
    # Cause-specific, time-dependent metrics per horizon.
    for h_i, tau in enumerate(eval_horizons):
        p_tau = cause_cif[:, :, h_i]  # (N, topK)
        y_tau = y_cause_within_tau[:, :, h_i]  # (N, topK)
        for j, cause_id in enumerate(top_cause_ids.tolist()):
            p = p_tau[:, j]
            y = y_tau[:, j]
            # Primary: CIF-based Brier score + ICI (calibration).
out_rows.append( { "model_name": model_name, "metric_name": "cause_brier", "horizon": float(tau), "cause": int(cause_id), "value": brier_score(p, y), "ci_low": "", "ci_high": "", } ) cal = calibration_deciles(p, y, n_bins=n_calib_bins) out_rows.append( { "model_name": model_name, "metric_name": "cause_ici", "horizon": float(tau), "cause": int(cause_id), "value": cal["ici"], "ci_low": "", "ci_high": "", } ) # Secondary: discrimination via AUC at the same horizon. if auc_ci_method == "none": auc, lo, hi = float("nan"), float("nan"), float("nan") elif auc_ci_method == "bootstrap": auc, lo, hi = bootstrap_auc_ci( p, y, n_bootstrap=bootstrap_n, alpha=0.95) else: auc, lo, hi = delong_ci(y, p, alpha=0.95) out_rows.append( { "model_name": model_name, "metric_name": "cause_auc", "horizon": float(tau), "cause": int(cause_id), "value": auc, "ci_low": lo, "ci_high": hi, } ) # Calibration curve bins for this cause + horizon. for binfo in cal.get("bins", []): calib_rows.append( { "model_name": model_name, "task": "cause_k", "horizon": float(tau), "cause_id": int(cause_id), "bin_index": int(binfo["bin"]), "p_mean": float(binfo["p_mean"]), "y_mean": float(binfo["y_mean"]), "n_in_bin": int(binfo["n"]), } ) def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None: fieldnames = [ "model_name", "task", "horizon", "cause_id", "bin_index", "p_mean", "y_mean", "n_in_bin", ] with open(path, "w", newline="") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() for r in rows: w.writerow(r) def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None: fieldnames = [ "model_name", "metric_name", "horizon", "cause", "value", "ci_low", "ci_high", ] with open(path, "w", newline="") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() for r in rows: w.writerow(r) def _make_eval_tag(split: str, offset_years: float) -> str: """Short tag for filenames written into run directories.""" off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".") 
return f"{split}_offset{off}y" def main() -> int: ap = argparse.ArgumentParser( description="Unified downstream evaluation via CIFs") ap.add_argument("--models_json", type=str, required=True, help="Path to models list JSON") ap.add_argument("--data_prefix", type=str, default="ukb", help="Dataset prefix") ap.add_argument("--split", type=str, default="test", choices=["train", "val", "test", "all"], help="Which split to evaluate") ap.add_argument("--offset_years", type=float, default=0.5, help="Anti-leakage offset (years)") ap.add_argument("--eval_horizons", type=float, nargs="*", default=DEFAULT_EVAL_HORIZONS) ap.add_argument("--top_k_causes", type=int, default=50) ap.add_argument("--batch_size", type=int, default=128) ap.add_argument("--num_workers", type=int, default=0) ap.add_argument("--seed", type=int, default=123) ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") ap.add_argument("--out_csv", type=str, default="eval_results.csv") ap.add_argument("--out_meta_json", type=str, default="eval_meta.json") # Integrity checks ap.add_argument("--integrity_strict", action="store_true", default=False) ap.add_argument("--integrity_tol", type=float, default=1e-6) # AUC CI methods ap.add_argument( "--auc_ci_method", type=str, default="delong", choices=["delong", "bootstrap", "none"], ) ap.add_argument("--bootstrap_n", type=int, default=2000) # Export settings for user-facing experiments ap.add_argument("--export_dir", type=str, default="eval_exports") ap.add_argument("--death_cause_id", type=int, default=DEFAULT_DEATH_CAUSE_ID) ap.add_argument("--focus_k", type=int, default=5, help="Additional non-death causes to include") ap.add_argument("--capture_k_pcts", type=int, nargs="*", default=[1, 5, 10, 20]) ap.add_argument( "--capture_curve_max_pct", type=int, default=50, help="If >0, also export a dense capture curve for k=1..max_pct", ) args = ap.parse_args() set_deterministic(args.seed) specs = load_models_json(args.models_json) if 
not specs: raise ValueError("No models provided") export_dir = str(args.export_dir) _ensure_dir(export_dir) cause_names = load_cause_names("labels.csv") # Determine top-K causes from the evaluation split only (model-agnostic), # aligned to time-dependent risk: occurrence within tau_max after context. first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path) cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] dataset_for_top = HealthDataset( data_prefix=args.data_prefix, covariate_list=cov_list) subset_for_top = build_eval_subset( dataset_for_top, train_ratio=float(first_cfg.get("train_ratio", 0.7)), val_ratio=float(first_cfg.get("val_ratio", 0.15)), seed=int(first_cfg.get("random_seed", 42)), split=args.split, ) loader_top = DataLoader( subset_for_top, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=health_collate_fn, ) tau_max = float(max(args.eval_horizons)) counts, n_total_eval = count_occurs_within_horizon( loader=loader_top, offset_years=args.offset_years, tau_years=tau_max, n_disease=dataset_for_top.n_disease, device=args.device, ) focus_causes = pick_focus_causes( counts_within_tau=counts, n_disease=int(dataset_for_top.n_disease), death_cause_id=int(args.death_cause_id), k=int(args.focus_k), ) top_cause_ids = np.asarray(focus_causes, dtype=int) # Export the chosen focus causes. focus_rows: List[Dict[str, Any]] = [] for r, cid in enumerate(focus_causes, start=1): row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)} if cid in cause_names: row["cause_name"] = cause_names[cid] focus_rows.append(row) focus_fieldnames = ["cause", "rank"] + \ (["cause_name"] if any("cause_name" in r for r in focus_rows) else []) write_simple_csv(os.path.join(export_dir, "focus_causes.csv"), focus_fieldnames, focus_rows) # Metadata for focus causes (within tau_max). 
top_causes_meta: List[Dict[str, Any]] = [] for cid in focus_causes: n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0 top_causes_meta.append( { "cause_id": int(cid), "tau_years": float(tau_max), "n_case_within_tau": n_case, "n_control_within_tau": int(n_total_eval - n_case), "n_total_eval": int(n_total_eval), } ) # Horizon groups for Experiment 3 hg_rows, horizon_to_group, hg_method = make_horizon_groups( args.eval_horizons) write_simple_csv( os.path.join(export_dir, "horizon_groups.csv"), ["horizon", "group", "group_rank"], hg_rows, ) rows: List[Dict[str, Any]] = [] calib_rows: List[Dict[str, Any]] = [] # Experiment exports (accumulated across models) rs_bins_rows: List[Dict[str, Any]] = [] rs_sum_rows: List[Dict[str, Any]] = [] cap_points_rows: List[Dict[str, Any]] = [] cap_curve_rows: List[Dict[str, Any]] = [] cal_group_sum_rows: List[Dict[str, Any]] = [] cal_group_bins_rows: List[Dict[str, Any]] = [] # Track per-model integrity status for meta JSON. integrity_meta: Dict[str, Any] = {} # Evaluate each model for spec in specs: run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path)) tag = _make_eval_tag(args.split, float(args.offset_years)) # Remember list offsets so we can write per-model slices to the model's run_dir. 
rows_start = len(rows) calib_start = len(calib_rows) cfg = load_train_config_for_checkpoint(spec.checkpoint_path) # Identifiers for consistent exports model_id = str(spec.name) model_type = str( cfg.get("model_type", spec.model_type if hasattr(spec, "model_type") else "")) loss_type_id = str( cfg.get("loss_type", spec.loss_type if hasattr(spec, "loss_type") else "")) age_encoder = str(cfg.get("age_encoder", "")) cov_type = "full" if _parse_bool( cfg.get("full_cov", False)) else "partial" cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] dataset = HealthDataset( data_prefix=args.data_prefix, covariate_list=cov_list) subset = build_eval_subset( dataset, train_ratio=float(cfg.get("train_ratio", 0.7)), val_ratio=float(cfg.get("val_ratio", 0.15)), seed=int(cfg.get("random_seed", 42)), split=args.split, ) loader = DataLoader( subset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=health_collate_fn, ) backbone, head, loss_type, bin_edges = instantiate_model_and_head( cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path) ckpt = torch.load(spec.checkpoint_path, map_location=args.device) backbone.load_state_dict(ckpt["model_state_dict"], strict=True) head.load_state_dict(ckpt["head_state_dict"], strict=True) ( cause_cif, cif_full, survival, y_cause_within_tau, sex, ) = predict_cifs_for_model( backbone, head, loss_type, bin_edges, loader, args.device, args.offset_years, args.eval_horizons, top_cause_ids, ) # CIF integrity checks before metrics. 
integrity_ok, integrity_notes = check_cif_integrity( cif_full, args.eval_horizons, tol=float(args.integrity_tol), name=spec.name, strict=bool(args.integrity_strict), survival=survival, ) integrity_meta[spec.name] = { "integrity_ok": bool(integrity_ok), "integrity_notes": integrity_notes, } evaluate_one_model( model_name=spec.name, cause_cif=cause_cif, y_cause_within_tau=y_cause_within_tau, eval_horizons=args.eval_horizons, top_cause_ids=top_cause_ids, out_rows=rows, calib_rows=calib_rows, auc_ci_method=str(args.auc_ci_method), bootstrap_n=int(args.bootstrap_n), ) # ============================================================ # Experiment 1: Risk stratification bins + summary # ============================================================ for sex_label, sex_mask in _sex_slices(sex if sex.size else None): for h_i, tau in enumerate(args.eval_horizons): for j, cause_id in enumerate(top_cause_ids.tolist()): p = cause_cif[:, j, h_i] y = y_cause_within_tau[:, j, h_i] if sex_mask is not None: p = p[sex_mask] y = y[sex_mask] q_used, bin_rows, summary = compute_risk_stratification_bins( p, y, q_default=10) for br in bin_rows: rs_bins_rows.append( { "model_id": model_id, "model_type": model_type, "loss_type": loss_type_id, "age_encoder": age_encoder, "cov_type": cov_type, "cause": int(cause_id), "horizon": float(tau), "sex": sex_label, "q": int(br["q"]), "n_bin": int(br["n_bin"]), "p_mean": _safe_float(br["p_mean"]), "y_rate": _safe_float(br["y_rate"]), "y_overall": _safe_float(br["y_overall"]), "lift_vs_overall": _safe_float(br["lift_vs_overall"]), "q_total": int(q_used), } ) rs_sum_rows.append( { "model_id": model_id, "model_type": model_type, "loss_type": loss_type_id, "age_encoder": age_encoder, "cov_type": cov_type, "cause": int(cause_id), "horizon": float(tau), "sex": sex_label, "q_total": int(q_used), "top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]), "bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]), "lift_top10_vs_bottom50": 
_safe_float(summary["lift_top10_vs_bottom50"]), "slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]), } ) # ============================================================ # Experiment 2: High-risk capture points (+ optional curve) # ============================================================ k_pcts = [int(x) for x in args.capture_k_pcts] curve_max = int(args.capture_curve_max_pct) curve_grid = list(range(1, curve_max + 1) ) if curve_max and curve_max > 0 else [] for sex_label, sex_mask in _sex_slices(sex if sex.size else None): for h_i, tau in enumerate(args.eval_horizons): for j, cause_id in enumerate(top_cause_ids.tolist()): p = cause_cif[:, j, h_i] y = y_cause_within_tau[:, j, h_i] if sex_mask is not None: p = p[sex_mask] y = y[sex_mask] for r in compute_capture_points(p, y, k_pcts): cap_points_rows.append( { "model_id": model_id, "model_type": model_type, "loss_type": loss_type_id, "age_encoder": age_encoder, "cov_type": cov_type, "cause": int(cause_id), "horizon": float(tau), "sex": sex_label, **r, } ) if curve_grid: for r in compute_capture_points(p, y, curve_grid): cap_curve_rows.append( { "model_id": model_id, "model_type": model_type, "loss_type": loss_type_id, "age_encoder": age_encoder, "cov_type": cov_type, "cause": int(cause_id), "horizon": float(tau), "sex": sex_label, **r, } ) # ============================================================ # Experiment 3: Short/Medium/Long horizon-group calibration # ============================================================ # Per-horizon metrics for grouping # Build a dict for quick access: (cause_id, horizon) -> (brier, ici) per_h: Dict[Tuple[int, float], Dict[str, float]] = {} for rr in rows[rows_start:]: if rr.get("model_name") != spec.name: continue if rr.get("metric_name") not in {"cause_brier", "cause_ici"}: continue try: cid = int(rr.get("cause")) except Exception: continue h = _safe_float(rr.get("horizon")) if not np.isfinite(h): continue key = (cid, float(h)) d = per_h.get(key, {}) 
d[str(rr.get("metric_name"))] = _safe_float(rr.get("value")) per_h[key] = d # Compute group summaries and pooled bins using the same quantile bins as exp1 (per slice). for sex_label, sex_mask in _sex_slices(sex if sex.size else None): for j, cause_id in enumerate(top_cause_ids.tolist()): # Decide Q per slice for pooled reliability curve n_slice = int(np.sum(sex_mask)) if sex_mask is not None else int( sex.shape[0]) q_pool = 10 if n_slice >= 200 else 5 # Collect per-horizon brier/ici values group_vals: Dict[str, Dict[str, List[float]]] = {"short": {"brier": [], "ici": [ ]}, "medium": {"brier": [], "ici": []}, "long": {"brier": [], "ici": []}} group_n_total: Dict[str, int] = { "short": 0, "medium": 0, "long": 0} # Pooled bins: group -> q -> accumulators pooled: Dict[str, Dict[int, Dict[str, float]]] = { "short": {}, "medium": {}, "long": {}} for h_i, tau in enumerate(args.eval_horizons): g = horizon_to_group.get(float(tau), "long") # brier/ici per horizon (already computed at full-sample level) d = per_h.get((int(cause_id), float(tau)), {}) brier_h = _safe_float(d.get("cause_brier")) ici_h = _safe_float(d.get("cause_ici")) if np.isfinite(brier_h): group_vals[g]["brier"].append(brier_h) if np.isfinite(ici_h): group_vals[g]["ici"].append(ici_h) # pooled reliability bins from raw p/y p = cause_cif[:, j, h_i] y = y_cause_within_tau[:, j, h_i] if sex_mask is not None: p = p[sex_mask] y = y[sex_mask] if p.size == 0: continue edges = _quantile_edges(p, q_pool) for qi in range(q_pool): m = (p > edges[qi]) & (p <= edges[qi + 1]) nb = int(np.sum(m)) if nb == 0: continue pm = float(np.mean(p[m])) yr = float(np.mean(y[m])) acc = pooled[g].get( qi + 1, {"n": 0.0, "p_sum": 0.0, "y_sum": 0.0}) acc["n"] += float(nb) acc["p_sum"] += float(nb) * pm acc["y_sum"] += float(nb) * yr pooled[g][qi + 1] = acc group_n_total[g] = max(group_n_total[g], int(p.size)) for g in ["short", "medium", "long"]: bvals = group_vals[g]["brier"] ivals = group_vals[g]["ici"] cal_group_sum_rows.append( { 
"model_id": model_id, "model_type": model_type, "loss_type": loss_type_id, "age_encoder": age_encoder, "cov_type": cov_type, "cause": int(cause_id), "sex": sex_label, "horizon_group": g, "brier_mean": float(np.mean(bvals)) if bvals else float("nan"), "brier_median": float(np.median(bvals)) if bvals else float("nan"), "ici_mean": float(np.mean(ivals)) if ivals else float("nan"), "ici_median": float(np.median(ivals)) if ivals else float("nan"), "n_total": int(group_n_total[g]), "horizon_grouping_method": hg_method, } ) for qi in range(1, q_pool + 1): acc = pooled[g].get(qi) if not acc or float(acc.get("n", 0.0)) <= 0: continue n_bin = float(acc["n"]) cal_group_bins_rows.append( { "model_id": model_id, "model_type": model_type, "loss_type": loss_type_id, "age_encoder": age_encoder, "cov_type": cov_type, "cause": int(cause_id), "sex": sex_label, "horizon_group": g, "q": int(qi), "n_bin": int(n_bin), "p_mean": float(acc["p_sum"] / n_bin), "y_rate": float(acc["y_sum"] / n_bin), "q_total": int(q_pool), "horizon_grouping_method": hg_method, } ) # Optionally write top-cause counts into the main results CSV as metric rows. for tc in top_causes_meta: rows.append( { "model_name": spec.name, "metric_name": "topcause_n_case_within_tau", "horizon": float(tc["tau_years"]), "cause": int(tc["cause_id"]), "value": int(tc["n_case_within_tau"]), "ci_low": "", "ci_high": "", } ) rows.append( { "model_name": spec.name, "metric_name": "topcause_n_control_within_tau", "horizon": float(tc["tau_years"]), "cause": int(tc["cause_id"]), "value": int(tc["n_control_within_tau"]), "ci_low": "", "ci_high": "", } ) rows.append( { "model_name": spec.name, "metric_name": "topcause_n_total_eval", "horizon": float(tc["tau_years"]), "cause": int(tc["cause_id"]), "value": int(tc["n_total_eval"]), "ci_low": "", "ci_high": "", } ) # Write per-model results into the model's run directory. 
model_rows = rows[rows_start:] model_calib_rows = calib_rows[calib_start:] model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv") model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv") model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json") write_results_csv(model_out_csv, model_rows) write_calibration_bins_csv(model_calib_csv, model_calib_rows) model_meta = { "model_name": spec.name, "checkpoint_path": spec.checkpoint_path, "run_dir": run_dir, "split": args.split, "offset_years": args.offset_years, "eval_horizons": [float(x) for x in args.eval_horizons], "top_k_causes": int(args.top_k_causes), "top_cause_ids": top_cause_ids.tolist(), "top_causes": top_causes_meta, "integrity": {spec.name: integrity_meta.get(spec.name, {})}, "paths": { "results_csv": model_out_csv, "calibration_bins_csv": model_calib_csv, }, } with open(model_meta_json, "w") as f: json.dump(model_meta, f, indent=2) print(f"Wrote per-model results to {model_out_csv}") write_results_csv(args.out_csv, rows) # Write calibration curve points to a separate CSV. out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "." 
calib_csv_path = os.path.join(out_dir, "calibration_bins.csv") write_calibration_bins_csv(calib_csv_path, calib_rows) # Write experiment exports write_simple_csv( os.path.join(export_dir, "risk_stratification_bins.csv"), [ "model_id", "model_type", "loss_type", "age_encoder", "cov_type", "cause", "horizon", "sex", "q", "n_bin", "p_mean", "y_rate", "y_overall", "lift_vs_overall", "q_total", ], rs_bins_rows, ) write_simple_csv( os.path.join(export_dir, "risk_stratification_summary.csv"), [ "model_id", "model_type", "loss_type", "age_encoder", "cov_type", "cause", "horizon", "sex", "q_total", "top_decile_y_rate", "bottom_half_y_rate", "lift_top10_vs_bottom50", "slope_pred_vs_obs", ], rs_sum_rows, ) write_simple_csv( os.path.join(export_dir, "lift_capture_points.csv"), [ "model_id", "model_type", "loss_type", "age_encoder", "cov_type", "cause", "horizon", "sex", "k_pct", "n_targeted", "events_targeted", "events_total", "event_capture_rate", "precision_in_targeted", ], cap_points_rows, ) if cap_curve_rows: write_simple_csv( os.path.join(export_dir, "lift_capture_curve.csv"), [ "model_id", "model_type", "loss_type", "age_encoder", "cov_type", "cause", "horizon", "sex", "k_pct", "n_targeted", "events_targeted", "events_total", "event_capture_rate", "precision_in_targeted", ], cap_curve_rows, ) write_simple_csv( os.path.join(export_dir, "calibration_groups_summary.csv"), [ "model_id", "model_type", "loss_type", "age_encoder", "cov_type", "cause", "sex", "horizon_group", "brier_mean", "brier_median", "ici_mean", "ici_median", "n_total", "horizon_grouping_method", ], cal_group_sum_rows, ) write_simple_csv( os.path.join(export_dir, "calibration_groups_bins.csv"), [ "model_id", "model_type", "loss_type", "age_encoder", "cov_type", "cause", "sex", "horizon_group", "q", "n_bin", "p_mean", "y_rate", "q_total", "horizon_grouping_method", ], cal_group_bins_rows, ) # Manifest markdown (stable, user-facing) manifest_path = os.path.join(export_dir, "eval_exports_manifest.md") with 
open(manifest_path, "w", encoding="utf-8") as f: f.write( "# Evaluation Exports Manifest\n\n" "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). " "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n" "## Files\n\n" "- focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table.\n" "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n" "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n" "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n" "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n" "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n" "- calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket.\n" "- calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). 
Intended plot: 3-panel reliability curves per model.\n" ) meta = { "split": args.split, "offset_years": args.offset_years, "eval_horizons": [float(x) for x in args.eval_horizons], "tau_max": float(tau_max), "top_k_causes": int(args.top_k_causes), "top_cause_ids": top_cause_ids.tolist(), "top_causes": top_causes_meta, "integrity": integrity_meta, "notes": { "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])", "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)", "secondary_metrics": "cause_auc (discrimination) with optional CI", "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported", "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.", "exports_dir": export_dir, "focus_causes": focus_causes, "horizon_grouping_method": hg_method, }, } with open(args.out_meta_json, "w") as f: json.dump(meta, f, indent=2) print(f"Wrote {args.out_csv} with {len(rows)} rows") print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows") print(f"Wrote {args.out_meta_json}") return 0 if __name__ == "__main__": raise SystemExit(main())