diff --git a/evaluate_models.py b/evaluate_models.py deleted file mode 100644 index 6c354db..0000000 --- a/evaluate_models.py +++ /dev/null @@ -1,2297 +0,0 @@ -import argparse -import csv -import json -import math -import os -import random -import sys -import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple - -import numpy as np -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader, random_split - -from dataset import HealthDataset, health_collate_fn -from model import DelphiFork, SapDelphi, SimpleHead - - -# ============================================================ -# Constants / defaults (aligned with evaluate_prompt.md) -# ============================================================ -DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")] -DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0] -DAYS_PER_YEAR = 365.25 -DEFAULT_DEATH_CAUSE_ID = 1256 - - -# ============================================================ -# Model specs -# ============================================================ -@dataclass(frozen=True) -class ModelSpec: - name: str - model_type: str # delphi_fork | sap_delphi - loss_type: str # exponential | discrete_time_cif | pwe_cif - full_cov: bool - checkpoint_path: str - - -# ============================================================ -# Determinism -# ============================================================ - -def set_deterministic(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -# ============================================================ -# Utilities -# ============================================================ - -def _parse_bool(x: Any) -> bool: - if isinstance(x, bool): - return x - s = str(x).strip().lower() - if s in {"true", "1", "yes", "y"}: - return True - if s in {"false", "0", "no", "n"}: - return False - raise ValueError(f"Cannot parse boolean: {x!r}") - - -def load_models_json(path: str) -> List[ModelSpec]: - with open(path, "r") as f: - data = json.load(f) - if not isinstance(data, list): - raise ValueError("models_json must be a list of model entries") - - specs: List[ModelSpec] = [] - for row in data: - specs.append( - ModelSpec( - name=str(row["name"]), - model_type=str(row["model_type"]), - loss_type=str(row["loss_type"]), - full_cov=_parse_bool(row["full_cov"]), - checkpoint_path=str(row["checkpoint_path"]), - ) - ) - return specs - - -def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]: - run_dir = os.path.dirname(os.path.abspath(checkpoint_path)) - cfg_path = os.path.join(run_dir, "train_config.json") - with open(cfg_path, "r") as f: - cfg = json.load(f) - return cfg - - -def build_eval_subset( - dataset: HealthDataset, - train_ratio: float, - val_ratio: float, - seed: int, - split: str, -): - n_total = len(dataset) - n_train = int(n_total * train_ratio) - n_val = int(n_total * val_ratio) - n_test = n_total - n_train - n_val - - train_ds, val_ds, test_ds = random_split( - dataset, - [n_train, n_val, n_test], - generator=torch.Generator().manual_seed(seed), - ) - - if split == "train": - return train_ds - if split == "val": - return val_ds - if split == "test": - return test_ds - if split == "all": - return dataset - raise ValueError("split must be one of: train, val, test, 
all") - - -# ============================================================ -# Context selection (anti-leakage) -# ============================================================ - -def select_context_indices( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - offset_years: float, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Select per-sample prediction context index. - - IMPORTANT SEMANTICS: - - The last observed token time is treated as the FOLLOW-UP END time. - - We pick the last valid token with time <= (followup_end_time - offset). - - We do NOT interpret followup_end_time as an event time. - - Returns: - keep_mask: (B,) bool, which samples have a valid context - t_ctx: (B,) long, index into sequence - t_ctx_time: (B,) float, time (days) at context - """ - # valid tokens are event != 0 (padding is 0) - valid = event_seq != 0 - lengths = valid.sum(dim=1) - last_idx = torch.clamp(lengths - 1, min=0) - - b = torch.arange(event_seq.size(0), device=event_seq.device) - followup_end_time = time_seq[b, last_idx] - t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR) - - eligible = valid & (time_seq <= t_cut.unsqueeze(1)) - eligible_counts = eligible.sum(dim=1) - keep = eligible_counts > 0 - - t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long) - t_ctx_time = time_seq[b, t_ctx] - - return keep, t_ctx, t_ctx_time - - -def next_event_after_context( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - t_ctx: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Return next disease event after context. - - Returns: - dt_years: (B,) float, time to next disease in years; +inf if none - cause: (B,) long, disease id in [0,K) for next event; -1 if none - """ - B, L = event_seq.shape - b = torch.arange(B, device=event_seq.device) - t0 = time_seq[b, t_ctx] - - # Allow same-day events while excluding the context token itself. - # We rely on time-sorted sequences and select the FIRST valid future event by index. - idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1) - future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0) - idx_min = torch.where( - future, idxs, torch.full_like(idxs, L)).min(dim=1).values - - has = idx_min < L - t_next = torch.where(has, idx_min, torch.zeros_like(idx_min)) - - t_next_time = time_seq[b, t_next] - dt_days = t_next_time - t0 - dt_years = dt_days / DAYS_PER_YEAR - dt_years = torch.where( - has, dt_years, torch.full_like(dt_years, float("inf"))) - - cause_token = event_seq[b, t_next] - cause = (cause_token - 2).to(torch.long) - cause = torch.where(has, cause, torch.full_like(cause, -1)) - - return dt_years, cause - - -def multi_hot_ever_within_horizon( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - t_ctx: torch.Tensor, - tau_years: float, - n_disease: int, -) -> torch.Tensor: - """Binary labels: disease k occurs within tau after context (any occurrence).""" - B, L = event_seq.shape - b = torch.arange(B, device=event_seq.device) - t0 = time_seq[b, t_ctx] - t1 = t0 + (tau_years * DAYS_PER_YEAR) - - idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1) - # Include same-day events after context, exclude any token at/before context index. 
- in_window = ( - (idxs > t_ctx.unsqueeze(1)) - & (time_seq >= t0.unsqueeze(1)) - & (time_seq <= t1.unsqueeze(1)) - & (event_seq >= 2) - & (event_seq != 0) - ) - - if not in_window.any(): - return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) - - b_idx, t_idx = in_window.nonzero(as_tuple=True) - disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) - - y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) - y[b_idx, disease_ids] = True - return y - - -def multi_hot_ever_after_context_anytime( - event_seq: torch.Tensor, - t_ctx: torch.Tensor, - n_disease: int, -) -> torch.Tensor: - """Binary labels: disease k occurs ANYTIME after the prediction context. - - This is Delphi2M-compatible for Task A case/control definition. - Same-day events are included as long as they occur after the context token index. - """ - B, L = event_seq.shape - idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1) - future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0) - - y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) - if not future.any(): - return y - - b_idx, t_idx = future.nonzero(as_tuple=True) - disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) - y[b_idx, disease_ids] = True - return y - - -def multi_hot_selected_causes_within_horizon( - event_seq: torch.Tensor, - time_seq: torch.Tensor, - t_ctx: torch.Tensor, - tau_years: float, - cause_ids: torch.Tensor, - n_disease: int, -) -> torch.Tensor: - """Labels for selected causes only: does cause k occur within tau after context?""" - B, L = event_seq.shape - device = event_seq.device - b = torch.arange(B, device=device) - t0 = time_seq[b, t_ctx] - t1 = t0 + (tau_years * DAYS_PER_YEAR) - - idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1) - in_window = ( - (idxs > t_ctx.unsqueeze(1)) - & (time_seq >= t0.unsqueeze(1)) - & (time_seq <= t1.unsqueeze(1)) - & (event_seq >= 2) - & (event_seq != 0) - ) - - out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device) - if not in_window.any(): - return out - - b_idx, t_idx = in_window.nonzero(as_tuple=True) - disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) - - # Filter to selected causes via a boolean membership mask over the global disease space. - selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device) - selected[cause_ids] = True - keep = selected[disease_ids] - if not keep.any(): - return out - - b_idx = b_idx[keep] - disease_ids = disease_ids[keep] - - # Map disease_id -> local index in cause_ids - # Build a lookup table (global disease space) where lookup[disease_id] = local_index - lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device) - lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device) - local = lookup[disease_ids] - out[b_idx, local] = True - return out - - -# ============================================================ -# CIF conversion -# ============================================================ - -def cifs_from_exponential_logits( - logits: torch.Tensor, - taus: Sequence[float], - eps: float = 1e-6, - return_survival: bool = False, -) -> torch.Tensor: - """Convert exponential cause-specific logits -> CIFs at taus. 
- - logits: (B, K) - returns: (B, K, H) or (cif, survival) if return_survival - """ - hazards = F.softplus(logits) + eps - total = hazards.sum(dim=1, keepdim=True) # (B,1) - - taus_t = torch.tensor(list(taus), device=logits.device, - dtype=hazards.dtype).view(1, 1, -1) - total_h = total.unsqueeze(-1) # (B,1,1) - - # (1 - exp(-Lambda * tau)) - one_minus_surv = 1.0 - torch.exp(-total_h * taus_t) - frac = hazards / torch.clamp(total, min=eps) - cif = frac.unsqueeze(-1) * one_minus_surv # (B,K,H) - # If total==0, set to 0 - cif = torch.where(total_h > 0, cif, torch.zeros_like(cif)) - - if not return_survival: - return cif - - survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H) - # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors. - nonzero = (total.squeeze(1) > 0).unsqueeze(1) - survival = torch.where(nonzero, survival, torch.ones_like(survival)) - return cif, survival - - -def cifs_from_discrete_time_logits( - logits: torch.Tensor, - bin_edges: Sequence[float], - taus: Sequence[float], - return_survival: bool = False, -) -> torch.Tensor: - """Convert discrete-time CIF logits -> CIFs at taus. - - logits: (B, K+1, n_bins+1) - bin_edges: len=n_bins+1 (including 0 and inf) - taus: subset of finite bin edges (recommended) - - returns: (B, K, H) or (cif, survival) if return_survival - """ - if logits.ndim != 3: - raise ValueError("Expected logits shape (B, K+1, n_bins+1)") - - B, K_plus_1, n_bins_plus_1 = logits.shape - K = K_plus_1 - 1 - - edges = [float(x) for x in bin_edges] - # drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf - finite_edges = [e for e in edges[1:] if math.isfinite(e)] - n_bins = len(finite_edges) - - if n_bins_plus_1 != len(edges): - raise ValueError("logits last dim must match len(bin_edges)") - - probs = torch.softmax(logits, dim=1) # (B, K+1, n_bins+1) - - # use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot) - hazards = probs[:, :K, 1: 1 + n_bins] # (B,K,n_bins) - p_comp = probs[:, K, 1: 1 + n_bins] # (B,n_bins) - - # survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v< u} p_comp[v] - ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype) - cum = torch.cumprod(p_comp, dim=1) - s_prev = torch.cat([ones, cum[:, :-1]], dim=1) # (B, n_bins) - - cif_bins = torch.cumsum(s_prev.unsqueeze( - 1) * hazards, dim=2) # (B,K,n_bins) - - # Robust mapping from tau -> edge index (floating-point safe). - # taus are expected to align with bin edges, but may differ slightly due to parsing/serialization. 
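    # Example with the module defaults: DEFAULT_BIN_EDGES yields finite edges
    # [0.24, 0.72, 1.61, 3.84, 10.0, 31.0], so DEFAULT_EVAL_HORIZONS
    # [0.72, 1.61, 3.84, 10.0] map to edge indices [1, 2, 3, 4]; any tau farther
    # than 1e-6 from every finite edge raises a ValueError below.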
- finite_edges_arr = np.asarray(finite_edges, dtype=float) - tau_to_idx: List[int] = [] - for tau in taus: - tau_f = float(tau) - if not math.isfinite(tau_f): - raise ValueError("taus must be finite for discrete-time CIF") - diffs = np.abs(finite_edges_arr - tau_f) - j = int(np.argmin(diffs)) - if diffs[j] > 1e-6: - raise ValueError( - f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})" - ) - tau_to_idx.append(j) - - idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long) - cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H) - - if not return_survival: - return cif - - # Survival at each horizon = prod_{u <= idx[h]} p_comp[u] - survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v] - survival = survival_bins.index_select(dim=1, index=idx) # (B,H) - return cif, survival - - -def cifs_from_pwe_logits( - logits: torch.Tensor, - bin_edges: Sequence[float], - taus: Sequence[float], - eps: float = 1e-6, - return_survival: bool = False, -) -> torch.Tensor: - """Convert piecewise-exponential (PWE) hazard logits -> CIFs at taus. - - logits: (B, K, n_bins) # hazard logits per cause per bin - bin_edges: length n_bins+1, strictly increasing, finite last edge - taus: subset of finite bin edges (recommended) - - returns: (B, K, H) or (cif, survival) if return_survival - """ - if logits.ndim != 3: - raise ValueError("Expected logits shape (B, K, n_bins) for pwe_cif") - - edges = [float(x) for x in bin_edges] - if len(edges) < 2: - raise ValueError("bin_edges must have length >= 2") - if edges[0] != 0.0: - raise ValueError("bin_edges[0] must equal 0.0") - if not math.isfinite(edges[-1]): - raise ValueError( - "pwe_cif requires a finite last bin edge (no +inf). " - "If your training config uses +inf, drop it for PWE evaluation." - ) - - B, K, n_bins = logits.shape - if n_bins != (len(edges) - 1): - raise ValueError( - f"logits last dim n_bins={n_bins} must equal len(bin_edges)-1={len(edges)-1}" - ) - - # Convert logits -> hazards, then integrated hazards per bin. - hazards = F.softplus(logits) + eps # (B,K,n_bins) - dt_bins = torch.tensor( - [edges[i + 1] - edges[i] for i in range(n_bins)], - device=logits.device, - dtype=hazards.dtype, - ) # (n_bins,) - if not torch.isfinite(dt_bins).all() or not (dt_bins > 0).all(): - raise ValueError("All PWE bin widths must be finite and > 0") - - H_cause = hazards * dt_bins.view(1, 1, n_bins) # (B,K,n_bins) - H_total = H_cause.sum(dim=1) # (B,n_bins) - - # Survival at START of each bin u. 
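    # i.e. S_prev[u] = exp(-sum_{v<u} H_total[v]). Single-cause example with made-up
    # integrated hazards H_total = [0.5, 0.3]: S_prev = [1.0, exp(-0.5)], the per-bin
    # CIF increments are [1 - exp(-0.5), exp(-0.5) * (1 - exp(-0.3))] ~= [0.393, 0.157],
    # and their cumulative sum 0.551 equals 1 - exp(-0.8) = 1 - survival, as expected.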
- cum_total = torch.cumsum(H_total, dim=1) # (B,n_bins) - zeros = torch.zeros((B, 1), device=logits.device, dtype=hazards.dtype) - cum_prev = torch.cat([zeros, cum_total[:, :-1]], dim=1) # (B,n_bins) - S_prev = torch.exp(-cum_prev) # (B,n_bins) - - one_minus_surv_bin = 1.0 - torch.exp(-H_total) # (B,n_bins) - frac = H_cause / torch.clamp(H_total.unsqueeze(1), min=eps) # (B,K,n_bins) - - cif_incr = S_prev.unsqueeze(1) * frac * one_minus_surv_bin.unsqueeze(1) - cif_bins = torch.cumsum(cif_incr, dim=2) # (B,K,n_bins) at edges[1:] - - # Map tau -> edge index in edges[1:] - finite_edges = edges[1:] - finite_edges_arr = np.asarray(finite_edges, dtype=float) - tau_to_idx: List[int] = [] - for tau in taus: - tau_f = float(tau) - if not math.isfinite(tau_f): - raise ValueError("taus must be finite for pwe_cif") - diffs = np.abs(finite_edges_arr - tau_f) - j = int(np.argmin(diffs)) - if diffs[j] > 1e-6: - raise ValueError( - f"tau={tau_f} not close to any bin edge (min |edge-tau|={diffs[j]})" - ) - tau_to_idx.append(j) - - idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long) - cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H) - - if not return_survival: - return cif - - # Survival at each horizon is exp(-cum_total at that edge) - survival_bins = torch.exp(-cum_total) # (B,n_bins) - survival = survival_bins.index_select(dim=1, index=idx) # (B,H) - return cif, survival - - -# ============================================================ -# CIF integrity checks -# ============================================================ - -def check_cif_integrity( - cause_cif: np.ndarray, - horizons: Sequence[float], - *, - tol: float = 1e-6, - name: str = "", - strict: bool = False, - survival: Optional[np.ndarray] = None, -) -> Tuple[bool, List[str]]: - """Run basic sanity checks on CIF arrays. 
- - Args: - cause_cif: (N, K, H) - horizons: length H - tol: tolerance for inequalities - name: model name for messages - strict: if True, raise ValueError on first failure - survival: optional (N, H) survival values at the same horizons - - Returns: - (integrity_ok, notes) - """ - notes: List[str] = [] - model_tag = f"[{name}] " if name else "" - - def _fail(msg: str) -> None: - full = model_tag + msg - if strict: - raise ValueError(full) - print("WARNING:", full) - notes.append(msg) - - cif = np.asarray(cause_cif) - if cif.ndim != 3: - _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}") - return False, notes - N, K, H = cif.shape - if H != len(horizons): - _fail( - f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})") - - # (5) Finite - if not np.isfinite(cif).all(): - _fail("integrity: non-finite values (NaN/Inf) in cause_cif") - - # (1) Range - cmin = float(np.nanmin(cif)) - cmax = float(np.nanmax(cif)) - if cmin < -tol: - _fail(f"integrity: range min={cmin} < -tol={-tol}") - if cmax > 1.0 + tol: - _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}") - - # (2) Monotonicity in horizons (per n,k) - diffs = np.diff(cif, axis=2) - if diffs.size > 0: - if np.nanmin(diffs) < -tol: - _fail("integrity: monotonicity violated (found negative diff along horizons)") - - # (3) Probability mass: sum_k CIF <= 1 - mass = np.sum(cif, axis=1) # (N,H) - mass_max = float(np.nanmax(mass)) - if mass_max > 1.0 + tol: - _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})") - - # (4) Conservation with survival, if provided - if survival is None: - warn = "integrity: survival not provided; skipping conservation check" - if strict: - # still skip (requested behavior), but keep message for context - notes.append(warn) - else: - print("WARNING:", model_tag + warn) - notes.append(warn) - else: - s = np.asarray(survival, dtype=float) - if s.shape != (N, H): - _fail( - f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})") - else: - recon = 1.0 - s - err = np.abs(recon - mass) - # Discrete-time should be very tight; exponential may accumulate slightly more numerical error. - tol_cons = max(float(tol), 1e-4) - if float(np.nanmax(err)) > tol_cons: - _fail( - f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})") - - ok = len([n for n in notes if not n.endswith( - "skipping conservation check")]) == 0 - return ok, notes - - -# ============================================================ -# Metrics -# ============================================================ - -# --- Rank-based ROC AUC (ties handled via midranks) --- - -def compute_midrank(x: np.ndarray) -> np.ndarray: - """Vectorized midrank computation (ties -> average ranks).""" - x = np.asarray(x, dtype=float) - n = int(x.shape[0]) - if n == 0: - return np.asarray([], dtype=float) - - order = np.argsort(x, kind="mergesort") - z = x[order] - - # Find tie groups in sorted order. - diff = np.diff(z) - # boundaries includes 0 and n - boundaries = np.concatenate( - [np.array([0], dtype=int), np.nonzero(diff != 0) - [0] + 1, np.array([n], dtype=int)] - ) - starts = boundaries[:-1] - ends = boundaries[1:] - lens = ends - starts - - # Midrank for each group in 1-based rank space. 
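    # Worked example with made-up scores x = [0.3, 0.1, 0.3, 0.2]: sorted z is
    # [0.1, 0.2, 0.3, 0.3], tie groups give starts = [0, 1, 2], ends = [1, 2, 4],
    # hence mids = [1.0, 2.0, 3.5] and the midranks in original order are
    # [3.5, 1.0, 3.5, 2.0].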
- mids = 0.5 * (starts + ends - 1) + 1.0 - t_sorted = np.repeat(mids, lens).astype(float, copy=False) - - out = np.empty(n, dtype=float) - out[order] = t_sorted - return out - - -def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float: - """Rank-based ROC AUC via Mann–Whitney U statistic (ties handled by midranks). - - Returns NaN for degenerate labels. - """ - y = np.asarray(y_true, dtype=int) - s = np.asarray(y_score, dtype=float) - if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]: - raise ValueError("roc_auc_rank expects 1D arrays of equal length") - m = int(np.sum(y == 1)) - n = int(np.sum(y == 0)) - if m == 0 or n == 0: - return float("nan") - - ranks = compute_midrank(s) - sum_pos = float(np.sum(ranks[y == 1])) - auc = (sum_pos - m * (m + 1) / 2.0) / (m * n) - return float(auc) - - -def brier_score(p: np.ndarray, y: np.ndarray) -> float: - p = np.asarray(p, dtype=float) - y = np.asarray(y, dtype=float) - return float(np.mean((p - y) ** 2)) - - -def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]: - bins, ici = _calibration_bins_and_ici( - p, y, n_bins=int(n_bins), return_bins=True) - return {"bins": bins, "ici": float(ici)} - - -def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float: - """Fast ICI only (no per-bin point export).""" - _, ici = _calibration_bins_and_ici( - p, y, n_bins=int(n_bins), return_bins=False) - return float(ici) - - -def _calibration_bins_and_ici( - p: np.ndarray, - y: np.ndarray, - *, - n_bins: int, - return_bins: bool, -) -> Tuple[List[Dict[str, Any]], float]: - """Vectorized quantile binning for calibration + ICI.""" - p = np.asarray(p, dtype=float) - y = np.asarray(y, dtype=float) - if p.size == 0: - return ([], float("nan")) if return_bins else ([], float("nan")) - - q = np.linspace(0.0, 1.0, int(n_bins) + 1) - edges = np.quantile(p, q) - edges = np.asarray(edges, dtype=float) - edges[0] = -np.inf - edges[-1] = np.inf - - # Bin assignment: i if edges[i] < p <= edges[i+1] - bin_idx = np.searchsorted(edges, p, side="right") - 1 - bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1) - - counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float) - sum_p = np.bincount(bin_idx, weights=p, - minlength=int(n_bins)).astype(float) - sum_y = np.bincount(bin_idx, weights=y, - minlength=int(n_bins)).astype(float) - - nonempty = counts > 0 - if not np.any(nonempty): - return ([], float("nan")) if return_bins else ([], float("nan")) - - p_mean = np.zeros(int(n_bins), dtype=float) - y_mean = np.zeros(int(n_bins), dtype=float) - p_mean[nonempty] = sum_p[nonempty] / counts[nonempty] - y_mean[nonempty] = sum_y[nonempty] / counts[nonempty] - - diffs = np.abs(p_mean[nonempty] - y_mean[nonempty]) - ici = float(np.mean(diffs)) if diffs.size else float("nan") - - if not return_bins: - return [], ici - - bins: List[Dict[str, Any]] = [] - idxs = np.nonzero(nonempty)[0] - for i in idxs.tolist(): - bins.append( - { - "bin": int(i), - "p_mean": float(p_mean[i]), - "y_mean": float(y_mean[i]), - "n": int(counts[i]), - } - ) - return bins, ici - - -def _progress_line(done: int, total: int, prefix: str = "") -> str: - total_i = max(1, int(total)) - done_i = max(0, min(int(done), total_i)) - width = 28 - frac = done_i / total_i - filled = int(round(width * frac)) - bar = "#" * filled + "-" * (width - filled) - pct = 100.0 * frac - return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)" - - -def _should_show_progress(mode: str) -> bool: - m = str(mode).strip().lower() - if m in {"0", "false", "no", 
"none", "off"}: - return False - # Default: show if interactive. - if m in {"auto", "1", "true", "yes", "on", "bar"}: - try: - return bool(sys.stdout.isatty()) - except Exception: - return True - return True - - -def _safe_float(x: Any, default: float = float("nan")) -> float: - try: - return float(x) - except Exception: - return float(default) - - -def _ensure_dir(path: str) -> None: - os.makedirs(path, exist_ok=True) - - -def load_cause_names(path: str = "labels.csv") -> Dict[int, str]: - """Load 0-based cause_id -> name mapping. - - labels.csv is assumed to be one label per line, in disease-id order. - """ - if not os.path.exists(path): - return {} - mapping: Dict[int, str] = {} - with open(path, "r", encoding="utf-8") as f: - for i, line in enumerate(f): - name = line.strip() - if name: - mapping[int(i)] = name - return mapping - - -def pick_focus_causes( - *, - counts_within_tau: Optional[np.ndarray], - n_disease: int, - death_cause_id: int = DEFAULT_DEATH_CAUSE_ID, - k: int = 5, -) -> List[int]: - """Pick focus causes for user-facing evaluation. - - Rule: - 1) Always include death_cause_id first. - 2) Then add K additional causes by descending event count if available. - If counts_within_tau is None, fall back to descending cause_id coverage proxy. - - Notes: - - counts_within_tau is expected to be shape (n_disease,). - - Deterministic: ties broken by smaller cause id. - """ - n_disease_i = int(n_disease) - if death_cause_id < 0 or death_cause_id >= n_disease_i: - print( - f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); " - "it will be omitted from focus causes." - ) - focus: List[int] = [] - else: - focus = [int(death_cause_id)] - - candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)] - - if counts_within_tau is not None: - c = np.asarray(counts_within_tau).astype(float) - if c.shape[0] != n_disease_i: - print( - "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering." - ) - counts_within_tau = None - else: - # Sort by (-count, cause_id) - order = sorted(candidates, key=lambda i: (-float(c[i]), int(i))) - order = [i for i in order if float(c[i]) > 0] - focus.extend([int(i) for i in order[: int(k)]]) - - if counts_within_tau is None: - # Fallback: deterministic coverage proxy (descending id, excluding death), then take K. - # (Real coverage requires data; this path is mostly for robustness.) - order = sorted(candidates, key=lambda i: (-int(i))) - focus.extend([int(i) for i in order[: int(k)]]) - - # De-dup while preserving order - seen = set() - out: List[int] = [] - for cid in focus: - if cid not in seen: - out.append(cid) - seen.add(cid) - return out - - -def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None: - _ensure_dir(os.path.dirname(os.path.abspath(path)) or ".") - with open(path, "w", newline="", encoding="utf-8") as f: - w = csv.DictWriter(f, fieldnames=fieldnames) - w.writeheader() - for r in rows: - w.writerow(r) - - -def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]: - """Return list of (sex_label, mask) slices including an 'all' slice. - - If sex is missing, returns only ('all', None). 
- """ - out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)] - if sex is None: - return out - s = np.asarray(sex) - if s.ndim != 1: - return out - for val in [0, 1]: - m = (s == val) - if int(np.sum(m)) > 0: - out.append((str(val), m)) - return out - - -def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray: - edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1)) - edges = np.asarray(edges, dtype=float) - edges[0] = -np.inf - edges[-1] = np.inf - return edges - - -def compute_risk_stratification_bins( - p: np.ndarray, - y: np.ndarray, - *, - q_default: int = 10, -) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]: - """Compute quantile-based risk strata and a compact summary.""" - p = np.asarray(p, dtype=float) - y = np.asarray(y, dtype=float) - n = int(p.shape[0]) - if n == 0: - return 0, [], { - "y_overall": float("nan"), - "top_decile_y_rate": float("nan"), - "bottom_half_y_rate": float("nan"), - "lift_top10_vs_bottom50": float("nan"), - "slope_pred_vs_obs": float("nan"), - } - - # Choose quantiles robustly. - q = int(q_default) - if n < 200: - q = 5 - - edges = _quantile_edges(p, q) - y_overall = float(np.mean(y)) - bin_rows: List[Dict[str, Any]] = [] - p_means: List[float] = [] - y_rates: List[float] = [] - n_bins: List[int] = [] - - for i in range(q): - mask = (p > edges[i]) & (p <= edges[i + 1]) - nb = int(np.sum(mask)) - if nb == 0: - # Keep the row for consistent plotting; set NaNs. - bin_rows.append( - { - "q": int(i + 1), - "n_bin": 0, - "p_mean": float("nan"), - "y_rate": float("nan"), - "y_overall": y_overall, - "lift_vs_overall": float("nan"), - } - ) - continue - p_mean = float(np.mean(p[mask])) - y_rate = float(np.mean(y[mask])) - lift = float(y_rate / y_overall) if y_overall > 0 else float("nan") - bin_rows.append( - { - "q": int(i + 1), - "n_bin": nb, - "p_mean": p_mean, - "y_rate": y_rate, - "y_overall": y_overall, - "lift_vs_overall": lift, - } - ) - p_means.append(p_mean) - y_rates.append(y_rate) - n_bins.append(nb) - - # Summary - top_mask = (p > edges[q - 1]) & (p <= edges[q]) - bot_half_mask = (p > edges[0]) & (p <= edges[q // 2]) - top_y = float(np.mean(y[top_mask])) if int( - np.sum(top_mask)) > 0 else float("nan") - bot_y = float(np.mean(y[bot_half_mask])) if int( - np.sum(bot_half_mask)) > 0 else float("nan") - lift_top_vs_bottom = float(top_y / bot_y) if (np.isfinite(top_y) - and np.isfinite(bot_y) and bot_y > 0) else float("nan") - - slope = float("nan") - if len(p_means) >= 2: - # Weighted least squares slope of y_rate ~ p_mean. 
- x = np.asarray(p_means, dtype=float) - yy = np.asarray(y_rates, dtype=float) - w = np.asarray(n_bins, dtype=float) - xm = float(np.average(x, weights=w)) - ym = float(np.average(yy, weights=w)) - denom = float(np.sum(w * (x - xm) ** 2)) - if denom > 0: - slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom) - - summary = { - "y_overall": y_overall, - "top_decile_y_rate": top_y, - "bottom_half_y_rate": bot_y, - "lift_top10_vs_bottom50": lift_top_vs_bottom, - "slope_pred_vs_obs": slope, - } - return q, bin_rows, summary - - -def compute_capture_points( - p: np.ndarray, - y: np.ndarray, - k_pcts: Sequence[int], -) -> List[Dict[str, Any]]: - p = np.asarray(p, dtype=float) - y = np.asarray(y, dtype=float) - n = int(p.shape[0]) - if n == 0: - return [] - order = np.argsort(-p) - y_sorted = y[order] - events_total = float(np.sum(y_sorted)) - - rows: List[Dict[str, Any]] = [] - for k in k_pcts: - kf = float(k) - n_targeted = int(math.ceil(n * kf / 100.0)) - n_targeted = max(1, min(n_targeted, n)) - events_targeted = float(np.sum(y_sorted[:n_targeted])) - capture = float(events_targeted / - events_total) if events_total > 0 else float("nan") - precision = float(events_targeted / float(n_targeted)) - rows.append( - { - "k_pct": int(k), - "n_targeted": int(n_targeted), - "events_targeted": float(events_targeted), - "events_total": float(events_total), - "event_capture_rate": capture, - "precision_in_targeted": precision, - } - ) - return rows - - -def compute_event_rate_at_topk_causes( - p_tau: np.ndarray, - y_tau: np.ndarray, - topk_list: Sequence[int], -) -> List[Dict[str, Any]]: - """Compute Event Rate@K for cross-cause prioritization. - - For each individual, rank causes by predicted risk p_tau at a fixed horizon. - For each K, select top-K causes and compute the fraction that occur within the horizon. - - Args: - p_tau: (N, K) predicted CIFs at a fixed horizon - y_tau: (N, K) binary labels (0/1) whether cause occurs within the horizon - topk_list: list of K values to evaluate - - Returns: - List of rows with: - - topk - - event_rate_mean / event_rate_median - - recall_mean / recall_median (averaged over individuals with >=1 true cause) - - n_total / n_valid_recall - """ - p = np.asarray(p_tau, dtype=float) - y = np.asarray(y_tau, dtype=float) - if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape: - raise ValueError( - "compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape") - - n, k_total = p.shape - if n == 0 or k_total == 0: - out: List[Dict[str, Any]] = [] - for kk in topk_list: - out.append( - { - "topk": int(max(1, int(kk))), - "event_rate_mean": float("nan"), - "event_rate_median": float("nan"), - "recall_mean": float("nan"), - "recall_median": float("nan"), - "n_total": int(n), - "n_valid_recall": 0, - } - ) - return out - - # Sanitize K list. - topks = sorted({int(x) for x in topk_list if int(x) > 0}) - if not topks: - return [] - - max_k = min(int(max(topks)), int(k_total)) - if max_k <= 0: - return [] - - # Efficient: get top max_k causes per individual, then sort within those. 
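    # Illustrative example: for one person with p = [0.1, 0.9, 0.4, 0.7] and
    # max_k = 2, argpartition puts the indices of the two largest risks {1, 3}
    # (in arbitrary order) into the first two slots, and the argsort over that
    # slice yields top_sorted = [[1, 3]] in descending-risk order.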
- part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k] # (N, max_k) - p_part = np.take_along_axis(p, part, axis=1) - order = np.argsort(-p_part, axis=1) - top_sorted = np.take_along_axis(part, order, axis=1) # (N, max_k) - - out_rows: List[Dict[str, Any]] = [] - for kk in topks: - kk_eff = min(int(kk), int(k_total)) - idx = top_sorted[:, :kk_eff] - y_sel = np.take_along_axis(y, idx, axis=1) - # Selected true causes per person - hit = np.sum(y_sel, axis=1) - # Precision-like: fraction of selected causes that occur - per_person = hit / \ - float(kk_eff) if kk_eff > 0 else np.full((n,), np.nan) - - # Recall@K: fraction of true causes covered by top-K (undefined when no true cause) - g = np.sum(y, axis=1) - valid = g > 0 - recall = np.full((n,), np.nan, dtype=float) - recall[valid] = hit[valid] / g[valid] - out_rows.append( - { - "topk": int(kk_eff), - "event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"), - "event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"), - "recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"), - "recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"), - "n_total": int(n), - "n_valid_recall": int(np.sum(valid)), - } - ) - return out_rows - - -def compute_random_ranking_baseline_topk( - y_tau: np.ndarray, - topk_list: Sequence[int], - *, - z: float = 1.645, -) -> List[Dict[str, Any]]: - """Random ranking baseline for Event Rate@K and Recall@K. - - Baseline definition: - - For each individual, pick K causes uniformly at random without replacement. - - EventRate@K = (# selected causes that occur) / K. - - Recall@K = (# selected causes that occur) / (# causes that occur), averaged over individuals with >=1 true cause. - - This function computes the expected baseline mean and an approximate 5-95% range - for the population mean using a normal approximation of the hypergeometric variance. - - Args: - y_tau: (N, K_total) binary labels - topk_list: K values - z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%) - - Returns: - Rows with baseline means and p05/p95 for both metrics. - """ - y = np.asarray(y_tau, dtype=float) - if y.ndim != 2: - raise ValueError( - "compute_random_ranking_baseline_topk expects y_tau with shape (N,K)") - - n, k_total = y.shape - topks = sorted({int(x) for x in topk_list if int(x) > 0}) - if not topks: - return [] - - g = np.sum(y, axis=1) # (N,) - valid = g > 0 - n_valid = int(np.sum(valid)) - - out: List[Dict[str, Any]] = [] - for kk in topks: - kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk) - if n == 0 or k_total == 0 or kk_eff <= 0: - out.append( - { - "topk": int(max(1, kk_eff)), - "baseline_event_rate_mean": float("nan"), - "baseline_event_rate_p05": float("nan"), - "baseline_event_rate_p95": float("nan"), - "baseline_recall_mean": float("nan"), - "baseline_recall_p05": float("nan"), - "baseline_recall_p95": float("nan"), - "n_total": int(n), - "n_valid_recall": int(n_valid), - "k_total": int(k_total), - "baseline_method": "random_ranking_hypergeometric_normal_approx", - } - ) - continue - - # Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total. - er_mean = float(np.mean(g / float(k_total))) - - # Variance of hypergeometric count X: - # Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p=g/K_total. 
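        # Numeric example with made-up counts: K_total = 10, g = 2, K = 3 gives
        # E[X] = 0.6, i.e. expected EventRate@3 = g / K_total = 0.2, and
        # Var(X) = 3 * 0.2 * 0.8 * (7 / 9) ~= 0.373, so Var(EventRate@3) ~= 0.041.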
- if k_total > 1 and kk_eff < k_total: - p = g / float(k_total) - finite_corr = (float(k_total - kk_eff) / float(k_total - 1)) - var_x = float(kk_eff) * p * (1.0 - p) * finite_corr - else: - var_x = np.zeros_like(g, dtype=float) - - var_er = var_x / (float(kk_eff) ** 2) - se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n)) - er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0)) - er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0)) - - # Expected Recall@K for individuals with g>0 is K/K_total (clipped). - rec_mean = float(min(float(kk_eff) / float(k_total), 1.0)) - if n_valid > 0: - var_rec = np.zeros_like(g, dtype=float) - gv = g[valid] - var_xv = var_x[valid] - # Var( X / g ) = Var(X) / g^2 (approx; g is fixed per individual) - var_rec_v = var_xv / (gv ** 2) - se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid) - rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0)) - rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0)) - else: - rec_p05 = float("nan") - rec_p95 = float("nan") - - out.append( - { - "topk": int(kk_eff), - "baseline_event_rate_mean": er_mean, - "baseline_event_rate_p05": er_p05, - "baseline_event_rate_p95": er_p95, - "baseline_recall_mean": rec_mean, - "baseline_recall_p05": float(rec_p05), - "baseline_recall_p95": float(rec_p95), - "n_total": int(n), - "n_valid_recall": int(n_valid), - "k_total": int(k_total), - "baseline_method": "random_ranking_hypergeometric_normal_approx", - } - ) - return out - - -def count_occurs_within_horizon( - loader: DataLoader, - offset_years: float, - tau_years: float, - n_disease: int, - device: str, -) -> Tuple[np.ndarray, int]: - """Count per-person occurrence within tau after the prediction context. - - Returns counts[k] = number of individuals with disease k at least once in (t_ctx, t_ctx+tau]. 
- """ - counts = torch.zeros((n_disease,), dtype=torch.long, device=device) - n_total_eval = 0 - for batch in loader: - event_seq, time_seq, cont_feats, cate_feats, sexes = batch - event_seq = event_seq.to(device) - time_seq = time_seq.to(device) - - keep, t_ctx, _ = select_context_indices( - event_seq, time_seq, offset_years) - if not keep.any(): - continue - - n_total_eval += int(keep.sum().item()) - event_seq = event_seq[keep] - time_seq = time_seq[keep] - t_ctx = t_ctx[keep] - - B, L = event_seq.shape - b = torch.arange(B, device=device) - t0 = time_seq[b, t_ctx] - t1 = t0 + (float(tau_years) * DAYS_PER_YEAR) - - idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1) - in_window = ( - (idxs > t_ctx.unsqueeze(1)) - & (time_seq >= t0.unsqueeze(1)) - & (time_seq <= t1.unsqueeze(1)) - & (event_seq >= 2) - & (event_seq != 0) - ) - if not in_window.any(): - continue - - b_idx, t_idx = in_window.nonzero(as_tuple=True) - disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) - - # unique per (person, disease) to count per-person within-window occurrence - key = b_idx.to(torch.long) * int(n_disease) + disease_ids - uniq = torch.unique(key) - uniq_disease = uniq % int(n_disease) - counts.scatter_add_(0, uniq_disease, torch.ones_like( - uniq_disease, dtype=torch.long)) - - return counts.detach().cpu().numpy(), int(n_total_eval) - - -# ============================================================ -# Evaluation core -# ============================================================ - -def instantiate_model_and_head( - cfg: Dict[str, Any], - dataset: HealthDataset, - device: str, - checkpoint_path: str = "", -) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]: - model_type = str(cfg["model_type"]) - loss_type = str(cfg["loss_type"]) - - bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES) - if loss_type == "exponential": - out_dims = [dataset.n_disease] - elif loss_type == "discrete_time_cif": - out_dims = [dataset.n_disease + 1, len(bin_edges)] - elif loss_type == "pwe_cif": - # Match training: drop +inf if present and evaluate up to the last finite edge. - pwe_edges = [float(x) for x in bin_edges if math.isfinite(float(x))] - if len(pwe_edges) < 2: - raise ValueError( - f"pwe_cif requires >=2 finite edges; got bin_edges={list(bin_edges)}" - ) - n_bins = len(pwe_edges) - 1 - out_dims = [dataset.n_disease, n_bins] - bin_edges = pwe_edges - else: - raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}") - - if model_type == "delphi_fork": - backbone = DelphiFork( - n_disease=dataset.n_disease, - n_tech_tokens=2, - n_embd=int(cfg["n_embd"]), - n_head=int(cfg["n_head"]), - n_layer=int(cfg["n_layer"]), - pdrop=float(cfg.get("pdrop", 0.0)), - age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), - n_cont=dataset.n_cont, - n_cate=dataset.n_cate, - cate_dims=dataset.cate_dims, - ).to(device) - elif model_type == "sap_delphi": - # Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path. - emb_path = cfg.get("pretrained_emb_path", None) - if emb_path in {"", None}: - emb_path = cfg.get("pretrained_emd_path", None) - if emb_path in {"", None}: - run_dir = os.path.dirname(os.path.abspath( - checkpoint_path)) if checkpoint_path else "" - print( - f"WARNING: SapDelphi pretrained embedding path missing in config " - f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). 
" - f"checkpoint={checkpoint_path} run_dir={run_dir}" - ) - backbone = SapDelphi( - n_disease=dataset.n_disease, - n_tech_tokens=2, - n_embd=int(cfg["n_embd"]), - n_head=int(cfg["n_head"]), - n_layer=int(cfg["n_layer"]), - pdrop=float(cfg.get("pdrop", 0.0)), - age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), - n_cont=dataset.n_cont, - n_cate=dataset.n_cate, - cate_dims=dataset.cate_dims, - pretrained_weights_path=emb_path, - freeze_embeddings=True, - ).to(device) - else: - raise ValueError(f"Unsupported model_type: {model_type}") - - head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device) - return backbone, head, loss_type, bin_edges - - -@torch.no_grad() -def predict_cifs_for_model( - backbone: torch.nn.Module, - head: torch.nn.Module, - loss_type: str, - bin_edges: Sequence[float], - loader: DataLoader, - device: str, - offset_years: float, - eval_horizons: Sequence[float], - n_disease: int, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Run model and produce cause-specific, time-dependent CIF outputs. - - Returns: - cif_full: (N, K, H) - survival: (N, H) - y_cause_within_tau: (N, K, H) - - NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk). - """ - backbone.eval() - head.eval() - - # We will accumulate in CPU lists, then concat. - cif_full_list: List[np.ndarray] = [] - survival_list: List[np.ndarray] = [] - y_cause_within_list: List[np.ndarray] = [] - sex_list: List[np.ndarray] = [] - - for batch in loader: - event_seq, time_seq, cont_feats, cate_feats, sexes = batch - event_seq = event_seq.to(device) - time_seq = time_seq.to(device) - cont_feats = cont_feats.to(device) - cate_feats = cate_feats.to(device) - sexes = sexes.to(device) - - keep, t_ctx, _ = select_context_indices( - event_seq, time_seq, offset_years) - if not keep.any(): - continue - - # filter batch - event_seq = event_seq[keep] - time_seq = time_seq[keep] - cont_feats = cont_feats[keep] - cate_feats = cate_feats[keep] - sexes_k = sexes[keep] - t_ctx = t_ctx[keep] - - h = backbone(event_seq, time_seq, sexes_k, - cont_feats, cate_feats) # (B,L,D) - b = torch.arange(h.size(0), device=device) - c = h[b, t_ctx] # (B,D) - logits = head(c) - - if loss_type == "exponential": - cif_full, survival = cifs_from_exponential_logits( - logits, eval_horizons, return_survival=True) # (B,K,H), (B,H) - elif loss_type == "discrete_time_cif": - cif_full, survival = cifs_from_discrete_time_logits( - # (B,K,H), (B,H) - logits, bin_edges, eval_horizons, return_survival=True) - elif loss_type == "pwe_cif": - cif_full, survival = cifs_from_pwe_logits( - logits, bin_edges, eval_horizons, return_survival=True) - else: - raise ValueError(f"Unsupported loss_type: {loss_type}") - - # Within-horizon labels for all causes: disease k occurs within tau after context. 
- y_within_full = torch.stack( - [ - multi_hot_ever_within_horizon( - event_seq=event_seq, - time_seq=time_seq, - t_ctx=t_ctx, - tau_years=float(tau), - n_disease=int(n_disease), - ).to(torch.float32) - for tau in eval_horizons - ], - dim=2, - ) # (B,K,H) - - cif_full_list.append(cif_full.detach().cpu().numpy()) - survival_list.append(survival.detach().cpu().numpy()) - y_cause_within_list.append(y_within_full.detach().cpu().numpy()) - sex_list.append(sexes_k.detach().cpu().numpy()) - - if not cif_full_list: - raise RuntimeError( - "No valid samples for evaluation (all batches filtered out by offset).") - - cif_full = np.concatenate(cif_full_list, axis=0) - survival = np.concatenate(survival_list, axis=0) - y_cause_within = np.concatenate(y_cause_within_list, axis=0) - sex = np.concatenate( - sex_list, axis=0) if sex_list else np.array([], dtype=int) - - return cif_full, survival, y_cause_within, sex - - -def evaluate_one_model( - model_name: str, - cif_full: np.ndarray, - y_cause_within_tau: np.ndarray, - eval_horizons: Sequence[float], - out_rows: List[Dict[str, Any]], - calib_rows: List[Dict[str, Any]], - calib_cause_ids: Optional[Sequence[int]], - n_calib_bins: int = 10, - metric_workers: int = 0, - progress: str = "auto", -) -> None: - """Compute per-cause metrics for ALL diseases. - - Notes: - - Writes scalar metrics for all causes into out_rows. - - Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable). - """ - cif_full = np.asarray(cif_full, dtype=float) - y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float) - if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3: - raise ValueError( - "Expected cif_full and y_cause_within_tau with shape (N, K, H)") - if cif_full.shape != y_cause_within_tau.shape: - raise ValueError( - f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}" - ) - - N, K, H = cif_full.shape - if H != len(eval_horizons): - raise ValueError("H mismatch between cif_full and eval_horizons") - - calib_set = set(int(x) - for x in calib_cause_ids) if calib_cause_ids is not None else set() - - workers = int(metric_workers) - if workers <= 0: - workers = int(min(8, os.cpu_count() or 1)) - workers = max(1, workers) - show_progress = _should_show_progress(progress) - - def _eval_chunk( - *, - tau: float, - p_tau: np.ndarray, - y_tau: np.ndarray, - brier_by_cause: np.ndarray, - cause_ids: np.ndarray, - ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]: - local_rows: List[Dict[str, Any]] = [] - local_calib: List[Dict[str, Any]] = [] - for cid in cause_ids.tolist(): - p = p_tau[:, cid] - y = y_tau[:, cid] - - local_rows.append( - { - "model_name": model_name, - "metric_name": "cause_brier", - "horizon": float(tau), - "cause": int(cid), - "value": float(brier_by_cause[cid]), - "ci_low": "", - "ci_high": "", - } - ) - - # ICI: compute bins only if we will export them. - need_bins = (not calib_set) or (int(cid) in calib_set) - if need_bins: - cal = calibration_deciles(p, y, n_bins=n_calib_bins) - ici = float(cal["ici"]) - else: - cal = None - ici = calibration_ici_only(p, y, n_bins=n_calib_bins) - - local_rows.append( - { - "model_name": model_name, - "metric_name": "cause_ici", - "horizon": float(tau), - "cause": int(cid), - "value": float(ici), - "ci_low": "", - "ci_high": "", - } - ) - - # Secondary: discrimination via AUC at the same horizon (point estimate only). 
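            # roc_auc_rank is the Mann-Whitney estimator: the probability that a
            # randomly chosen case outranks a randomly chosen control, with ties
            # counted as 1/2. E.g. y = [0, 0, 1, 1], p = [0.1, 0.4, 0.35, 0.8]
            # orders 3 of the 4 case-control pairs correctly, so AUC = 0.75;
            # degenerate labels (no cases or no controls) return NaN.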
- auc = roc_auc_rank(y, p) - - local_rows.append( - { - "model_name": model_name, - "metric_name": "cause_auc", - "horizon": float(tau), - "cause": int(cid), - "value": float(auc), - "ci_low": "", - "ci_high": "", - } - ) - - if need_bins and cal is not None: - for binfo in cal.get("bins", []): - local_calib.append( - { - "model_name": model_name, - "task": "cause_k", - "horizon": float(tau), - "cause_id": int(cid), - "bin_index": int(binfo["bin"]), - "p_mean": float(binfo["p_mean"]), - "y_mean": float(binfo["y_mean"]), - "n_in_bin": int(binfo["n"]), - } - ) - return local_rows, local_calib, int(cause_ids.size) - - # Cause-specific, time-dependent metrics per horizon. - for h_i, tau in enumerate(eval_horizons): - p_tau = cif_full[:, :, h_i] # (N, K) - y_tau = y_cause_within_tau[:, :, h_i] # (N, K) - - # Vectorized Brier for speed. - brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0) # (K,) - - # Parallelize disease-level metrics; chunk to avoid millions of futures. - all_ids = np.arange(int(K), dtype=int) - chunks = np.array_split(all_ids, workers) - done = 0 - prefix = f"[{model_name}] tau={float(tau)}y " - t0 = time.time() - - if workers <= 1: - for ch in chunks: - r_chunk, c_chunk, n_done = _eval_chunk( - tau=float(tau), - p_tau=p_tau, - y_tau=y_tau, - brier_by_cause=brier_by_cause, - cause_ids=ch, - ) - out_rows.extend(r_chunk) - calib_rows.extend(c_chunk) - done += int(n_done) - if show_progress: - sys.stdout.write( - "\r" + _progress_line(done, int(K), prefix=prefix)) - sys.stdout.flush() - else: - with ThreadPoolExecutor(max_workers=workers) as ex: - futs = [ - ex.submit( - _eval_chunk, - tau=float(tau), - p_tau=p_tau, - y_tau=y_tau, - brier_by_cause=brier_by_cause, - cause_ids=ch, - ) - for ch in chunks - if int(ch.size) > 0 - ] - for fut in as_completed(futs): - r_chunk, c_chunk, n_done = fut.result() - out_rows.extend(r_chunk) - calib_rows.extend(c_chunk) - done += int(n_done) - if show_progress: - sys.stdout.write( - "\r" + _progress_line(done, int(K), prefix=prefix)) - sys.stdout.flush() - - if show_progress: - dt = time.time() - t0 - sys.stdout.write("\r" + _progress_line(int(K), - int(K), prefix=prefix) + f" ({dt:.1f}s)\n") - sys.stdout.flush() - - -def summarize_over_diseases( - rows: List[Dict[str, Any]], - *, - model_name: str, - eval_horizons: Sequence[float], - metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"), -) -> List[Dict[str, Any]]: - """Summarize mean/median of each metric over diseases (per horizon).""" - out: List[Dict[str, Any]] = [] - # Build metric_name -> horizon -> list of values - bucket: Dict[Tuple[str, float], List[float]] = {} - for r in rows: - if r.get("model_name") != model_name: - continue - m = str(r.get("metric_name")) - if m not in set(metrics): - continue - h = _safe_float(r.get("horizon")) - v = _safe_float(r.get("value")) - if not np.isfinite(h): - continue - if not np.isfinite(v): - continue - bucket.setdefault((m, float(h)), []).append(float(v)) - - for tau in eval_horizons: - ht = float(tau) - for m in metrics: - vals = bucket.get((str(m), ht), []) - if vals: - arr = np.asarray(vals, dtype=float) - mean_v = float(np.mean(arr)) - med_v = float(np.median(arr)) - n_valid = int(arr.size) - else: - mean_v = float("nan") - med_v = float("nan") - n_valid = 0 - out.append( - { - "model_name": str(model_name), - "metric_name": str(m), - "horizon": ht, - "mean": mean_v, - "median": med_v, - "n_valid": n_valid, - } - ) - return out - - -def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None: - fieldnames = [ 
- "model_name", - "task", - "horizon", - "cause_id", - "bin_index", - "p_mean", - "y_mean", - "n_in_bin", - ] - with open(path, "w", newline="") as f: - w = csv.DictWriter(f, fieldnames=fieldnames) - w.writeheader() - for r in rows: - w.writerow(r) - - -def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None: - fieldnames = [ - "model_name", - "metric_name", - "horizon", - "cause", - "value", - "ci_low", - "ci_high", - ] - with open(path, "w", newline="") as f: - w = csv.DictWriter(f, fieldnames=fieldnames) - w.writeheader() - for r in rows: - w.writerow(r) - - -def _make_eval_tag(split: str, offset_years: float) -> str: - """Short tag for filenames written into run directories.""" - off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".") - return f"{split}_offset{off}y" - - -def main() -> int: - ap = argparse.ArgumentParser( - description="Unified downstream evaluation via CIFs") - ap.add_argument("--models_json", type=str, required=True, - help="Path to models list JSON") - ap.add_argument("--data_prefix", type=str, - default="ukb", help="Dataset prefix") - ap.add_argument("--split", type=str, default="test", - choices=["train", "val", "test", "all"], help="Which split to evaluate") - ap.add_argument("--offset_years", type=float, default=0.5, - help="Anti-leakage offset (years)") - ap.add_argument("--eval_horizons", type=float, - nargs="*", default=DEFAULT_EVAL_HORIZONS) - ap.add_argument("--batch_size", type=int, default=128) - ap.add_argument("--num_workers", type=int, default=0) - ap.add_argument("--seed", type=int, default=123) - ap.add_argument("--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") - ap.add_argument("--out_csv", type=str, default="eval_summary.csv") - ap.add_argument("--out_meta_json", type=str, default="eval_meta.json") - - # Integrity checks - ap.add_argument("--integrity_strict", action="store_true", default=False) - ap.add_argument("--integrity_tol", type=float, default=1e-6) - - # Speed/UX - ap.add_argument( - "--metric_workers", - type=int, - default=0, - help="Threads for per-disease metrics (0=auto, 1=disable parallelism)", - ) - ap.add_argument( - "--progress", - type=str, - default="auto", - choices=["auto", "bar", "none"], - help="Progress visualization during per-disease evaluation", - ) - - # Export settings for user-facing experiments - ap.add_argument("--export_dir", type=str, default="eval_exports") - ap.add_argument("--death_cause_id", type=int, - default=DEFAULT_DEATH_CAUSE_ID) - ap.add_argument("--focus_k", type=int, default=5, - help="Additional non-death causes to include") - ap.add_argument("--capture_k_pcts", type=int, - nargs="*", default=[1, 5, 10, 20]) - ap.add_argument( - "--capture_curve_max_pct", - type=int, - default=50, - help="If >0, also export a dense capture curve for k=1..max_pct", - ) - - # High-risk cause concentration (cross-cause prioritization) - ap.add_argument( - "--cause_concentration_topk", - type=int, - nargs="*", - default=[5, 10, 20, 50], - help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)", - ) - ap.add_argument( - "--cause_concentration_write_random_baseline", - action="store_true", - default=False, - help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)", - ) - args = ap.parse_args() - - set_deterministic(args.seed) - - specs = load_models_json(args.models_json) - if not specs: - raise ValueError("No models provided") - - export_dir = str(args.export_dir) - _ensure_dir(export_dir) - 
cause_names = load_cause_names("labels.csv") - - # Determine top-K causes from the evaluation split only (model-agnostic), - # aligned to time-dependent risk: occurrence within tau_max after context. - first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path) - cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [ - "bmi", "smoking", "alcohol"] - dataset_for_top = HealthDataset( - data_prefix=args.data_prefix, covariate_list=cov_list) - subset_for_top = build_eval_subset( - dataset_for_top, - train_ratio=float(first_cfg.get("train_ratio", 0.7)), - val_ratio=float(first_cfg.get("val_ratio", 0.15)), - seed=int(first_cfg.get("random_seed", 42)), - split=args.split, - ) - loader_top = DataLoader( - subset_for_top, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.num_workers, - collate_fn=health_collate_fn, - ) - - tau_max = float(max(args.eval_horizons)) - counts, n_total_eval = count_occurs_within_horizon( - loader=loader_top, - offset_years=args.offset_years, - tau_years=tau_max, - n_disease=dataset_for_top.n_disease, - device=args.device, - ) - - focus_causes = pick_focus_causes( - counts_within_tau=counts, - n_disease=int(dataset_for_top.n_disease), - death_cause_id=int(args.death_cause_id), - k=int(args.focus_k), - ) - top_cause_ids = np.asarray(focus_causes, dtype=int) - - # Export the chosen focus causes. - focus_rows: List[Dict[str, Any]] = [] - for r, cid in enumerate(focus_causes, start=1): - row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)} - if cid in cause_names: - row["cause_name"] = cause_names[cid] - focus_rows.append(row) - focus_fieldnames = ["cause", "rank"] + \ - (["cause_name"] if any("cause_name" in r for r in focus_rows) else []) - write_simple_csv(os.path.join(export_dir, "focus_causes.csv"), - focus_fieldnames, focus_rows) - - # Metadata for focus causes (within tau_max). - top_causes_meta: List[Dict[str, Any]] = [] - for cid in focus_causes: - n_case = int(counts[int(cid)]) if int( - cid) < int(counts.shape[0]) else 0 - top_causes_meta.append( - { - "cause_id": int(cid), - "tau_years": float(tau_max), - "n_case_within_tau": n_case, - "n_control_within_tau": int(n_total_eval - n_case), - "n_total_eval": int(n_total_eval), - } - ) - - summary_rows: List[Dict[str, Any]] = [] - calib_rows: List[Dict[str, Any]] = [] - - # Experiment exports (accumulated across models) - cap_points_rows: List[Dict[str, Any]] = [] - cap_curve_rows: List[Dict[str, Any]] = [] - conc_rows: List[Dict[str, Any]] = [] - conc_base_rows: List[Dict[str, Any]] = [] - - # Track per-model integrity status for meta JSON. - integrity_meta: Dict[str, Any] = {} - - # Evaluate each model - for spec in specs: - run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path)) - tag = _make_eval_tag(args.split, float(args.offset_years)) - - # Remember list offsets so we can write per-model slices to the model's run_dir. 
- calib_start = len(calib_rows) - - cfg = load_train_config_for_checkpoint(spec.checkpoint_path) - - # Identifiers for consistent exports - model_id = str(spec.name) - model_type = str( - cfg.get("model_type", spec.model_type if hasattr(spec, "model_type") else "")) - loss_type_id = str( - cfg.get("loss_type", spec.loss_type if hasattr(spec, "loss_type") else "")) - age_encoder = str(cfg.get("age_encoder", "")) - cov_type = "full" if _parse_bool( - cfg.get("full_cov", False)) else "partial" - - cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [ - "bmi", "smoking", "alcohol"] - dataset = HealthDataset( - data_prefix=args.data_prefix, covariate_list=cov_list) - subset = build_eval_subset( - dataset, - train_ratio=float(cfg.get("train_ratio", 0.7)), - val_ratio=float(cfg.get("val_ratio", 0.15)), - seed=int(cfg.get("random_seed", 42)), - split=args.split, - ) - loader = DataLoader( - subset, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.num_workers, - collate_fn=health_collate_fn, - ) - - backbone, head, loss_type, bin_edges = instantiate_model_and_head( - cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path) - ckpt = torch.load(spec.checkpoint_path, map_location=args.device) - backbone.load_state_dict(ckpt["model_state_dict"], strict=True) - head.load_state_dict(ckpt["head_state_dict"], strict=True) - - ( - cif_full, - survival, - y_cause_within_tau, - sex, - ) = predict_cifs_for_model( - backbone, - head, - loss_type, - bin_edges, - loader, - args.device, - args.offset_years, - args.eval_horizons, - n_disease=int(dataset.n_disease), - ) - - # CIF integrity checks before metrics. - integrity_ok, integrity_notes = check_cif_integrity( - cif_full, - args.eval_horizons, - tol=float(args.integrity_tol), - name=spec.name, - strict=bool(args.integrity_strict), - survival=survival, - ) - integrity_meta[spec.name] = { - "integrity_ok": bool(integrity_ok), - "integrity_notes": integrity_notes, - } - - # Per-disease metrics for ALL diseases (written into the model's run_dir). - model_rows: List[Dict[str, Any]] = [] - evaluate_one_model( - model_name=spec.name, - cif_full=cif_full, - y_cause_within_tau=y_cause_within_tau, - eval_horizons=args.eval_horizons, - out_rows=model_rows, - calib_rows=calib_rows, - calib_cause_ids=top_cause_ids.tolist(), - metric_workers=int(args.metric_workers), - progress=str(args.progress), - ) - - # Summary over diseases (mean/median per horizon). 
- model_summary_rows = summarize_over_diseases( - model_rows, - model_name=spec.name, - eval_horizons=args.eval_horizons, - ) - summary_rows.extend(model_summary_rows) - - # ============================================================ - # Experiment: High-Risk Cause Concentration at fixed horizon - # (cross-cause prioritization accuracy) - # ============================================================ - topk_causes = [int(x) for x in args.cause_concentration_topk] - for sex_label, sex_mask in _sex_slices(sex if sex.size else None): - for h_i, tau in enumerate(args.eval_horizons): - p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float) - y_tau_all = np.asarray( - y_cause_within_tau[:, :, h_i], dtype=float) - if sex_mask is not None: - p_tau_all = p_tau_all[sex_mask] - y_tau_all = y_tau_all[sex_mask] - for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes): - conc_rows.append( - { - "model_id": model_id, - "model_type": model_type, - "loss_type": loss_type_id, - "age_encoder": age_encoder, - "cov_type": cov_type, - "horizon": float(tau), - "sex": sex_label, - **rr, - } - ) - - if bool(args.cause_concentration_write_random_baseline): - for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes): - conc_base_rows.append( - { - "model_id": model_id, - "model_type": model_type, - "loss_type": loss_type_id, - "age_encoder": age_encoder, - "cov_type": cov_type, - "horizon": float(tau), - "sex": sex_label, - **rr, - } - ) - - # Convenience slices for user-facing experiments (focus causes only). - cause_cif_focus = cif_full[:, top_cause_ids, :] - y_within_focus = y_cause_within_tau[:, top_cause_ids, :] - - # ============================================================ - # Experiment 2: High-risk capture points (+ optional curve) - # ============================================================ - k_pcts = [int(x) for x in args.capture_k_pcts] - curve_max = int(args.capture_curve_max_pct) - curve_grid = list(range(1, curve_max + 1) - ) if curve_max and curve_max > 0 else [] - for sex_label, sex_mask in _sex_slices(sex if sex.size else None): - for h_i, tau in enumerate(args.eval_horizons): - for j, cause_id in enumerate(top_cause_ids.tolist()): - p = cause_cif_focus[:, j, h_i] - y = y_within_focus[:, j, h_i] - if sex_mask is not None: - p = p[sex_mask] - y = y[sex_mask] - - for r in compute_capture_points(p, y, k_pcts): - cap_points_rows.append( - { - "model_id": model_id, - "model_type": model_type, - "loss_type": loss_type_id, - "age_encoder": age_encoder, - "cov_type": cov_type, - "cause": int(cause_id), - "horizon": float(tau), - "sex": sex_label, - **r, - } - ) - if curve_grid: - for r in compute_capture_points(p, y, curve_grid): - cap_curve_rows.append( - { - "model_id": model_id, - "model_type": model_type, - "loss_type": loss_type_id, - "age_encoder": age_encoder, - "cov_type": cov_type, - "cause": int(cause_id), - "horizon": float(tau), - "sex": sex_label, - **r, - } - ) - - # Optionally write top-cause counts into the main results CSV as metric rows. 
- for tc in top_causes_meta: - model_rows.append( - { - "model_name": spec.name, - "metric_name": "topcause_n_case_within_tau", - "horizon": float(tc["tau_years"]), - "cause": int(tc["cause_id"]), - "value": int(tc["n_case_within_tau"]), - "ci_low": "", - "ci_high": "", - } - ) - model_rows.append( - { - "model_name": spec.name, - "metric_name": "topcause_n_control_within_tau", - "horizon": float(tc["tau_years"]), - "cause": int(tc["cause_id"]), - "value": int(tc["n_control_within_tau"]), - "ci_low": "", - "ci_high": "", - } - ) - model_rows.append( - { - "model_name": spec.name, - "metric_name": "topcause_n_total_eval", - "horizon": float(tc["tau_years"]), - "cause": int(tc["cause_id"]), - "value": int(tc["n_total_eval"]), - "ci_low": "", - "ci_high": "", - } - ) - - # Write per-model results into the model's run directory. - model_calib_rows = calib_rows[calib_start:] - model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv") - model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv") - model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv") - model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json") - - write_results_csv(model_out_csv, model_rows) - write_simple_csv( - model_summary_csv, - ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"], - model_summary_rows, - ) - write_calibration_bins_csv(model_calib_csv, model_calib_rows) - - model_meta = { - "model_name": spec.name, - "checkpoint_path": spec.checkpoint_path, - "run_dir": run_dir, - "split": args.split, - "offset_years": args.offset_years, - "eval_horizons": [float(x) for x in args.eval_horizons], - "n_disease": int(dataset.n_disease), - "top_cause_ids": top_cause_ids.tolist(), - "top_causes": top_causes_meta, - "integrity": {spec.name: integrity_meta.get(spec.name, {})}, - "paths": { - "results_csv": model_out_csv, - "summary_csv": model_summary_csv, - "calibration_bins_csv": model_calib_csv, - }, - } - with open(model_meta_json, "w") as f: - json.dump(model_meta, f, indent=2) - - print(f"Wrote per-model results to {model_out_csv}") - - # Write global summary (across diseases) across all models. - write_simple_csv( - args.out_csv, - ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"], - summary_rows, - ) - - # Write calibration curve points to a separate CSV. - out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "." 
- calib_csv_path = os.path.join(out_dir, "calibration_bins.csv") - write_calibration_bins_csv(calib_csv_path, calib_rows) - - # Write experiment exports - write_simple_csv( - os.path.join(export_dir, "lift_capture_points.csv"), - [ - "model_id", - "model_type", - "loss_type", - "age_encoder", - "cov_type", - "cause", - "horizon", - "sex", - "k_pct", - "n_targeted", - "events_targeted", - "events_total", - "event_capture_rate", - "precision_in_targeted", - ], - cap_points_rows, - ) - if cap_curve_rows: - write_simple_csv( - os.path.join(export_dir, "lift_capture_curve.csv"), - [ - "model_id", - "model_type", - "loss_type", - "age_encoder", - "cov_type", - "cause", - "horizon", - "sex", - "k_pct", - "n_targeted", - "events_targeted", - "events_total", - "event_capture_rate", - "precision_in_targeted", - ], - cap_curve_rows, - ) - - write_simple_csv( - os.path.join(export_dir, "high_risk_cause_concentration.csv"), - [ - "model_id", - "model_type", - "loss_type", - "age_encoder", - "cov_type", - "horizon", - "sex", - "topk", - "event_rate_mean", - "event_rate_median", - "recall_mean", - "recall_median", - "n_total", - "n_valid_recall", - ], - conc_rows, - ) - - if conc_base_rows: - write_simple_csv( - os.path.join( - export_dir, "high_risk_cause_concentration_random_baseline.csv"), - [ - "model_id", - "model_type", - "loss_type", - "age_encoder", - "cov_type", - "horizon", - "sex", - "topk", - "baseline_event_rate_mean", - "baseline_event_rate_p05", - "baseline_event_rate_p95", - "baseline_recall_mean", - "baseline_recall_p05", - "baseline_recall_p95", - "n_total", - "n_valid_recall", - "k_total", - "baseline_method", - ], - conc_base_rows, - ) - - # Manifest markdown (stable, user-facing) - manifest_path = os.path.join(export_dir, "eval_exports_manifest.md") - with open(manifest_path, "w", encoding="utf-8") as f: - f.write( - "# Evaluation Exports Manifest\n\n" - "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). " - "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n" - "## Files\n\n" - "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n" - "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n" - "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n" - "- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). 
Intended plot: line chart vs K.\n" - "- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n" - ) - - meta = { - "split": args.split, - "offset_years": args.offset_years, - "eval_horizons": [float(x) for x in args.eval_horizons], - "tau_max": float(tau_max), - "n_disease": int(dataset_for_top.n_disease), - "top_cause_ids": top_cause_ids.tolist(), - "top_causes": top_causes_meta, - "integrity": integrity_meta, - "notes": { - "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])", - "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)", - "secondary_metrics": "cause_auc (discrimination)", - "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported", - "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.", - "exports_dir": export_dir, - "focus_causes": focus_causes, - }, - } - with open(args.out_meta_json, "w") as f: - json.dump(meta, f, indent=2) - - print(f"Wrote {args.out_csv} with {len(summary_rows)} rows") - print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows") - print(f"Wrote {args.out_meta_json}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/models_eval_example.json b/models_eval_example.json deleted file mode 100644 index f2054e4..0000000 --- a/models_eval_example.json +++ /dev/null @@ -1,114 +0,0 @@ -[ - { - "name": "delphi_fork_discrete_time_cif_mlp_fullcov", - "model_type": "delphi_fork", - "loss_type": "discrete_time_cif", - "full_cov": true, - "checkpoint_path": "runs/delphi_fork_discrete_time_cif_mlp_fullcov_20260110-021823/best_model.pt" - }, - { - "name": "delphi_fork_discrete_time_cif_mlp_partcov", - "model_type": "delphi_fork", - "loss_type": "discrete_time_cif", - "full_cov": false, - "checkpoint_path": "runs/delphi_fork_discrete_time_cif_mlp_partcov_20260110-013741/best_model.pt" - }, - { - "name": "delphi_fork_discrete_time_cif_sinusoidal_fullcov", - "model_type": "delphi_fork", - "loss_type": "discrete_time_cif", - "full_cov": true, - "checkpoint_path": "runs/delphi_fork_discrete_time_cif_sinusoidal_fullcov_20260109-222502/best_model.pt" - }, - { - "name": "delphi_fork_discrete_time_cif_sinusoidal_partcov", - "model_type": "delphi_fork", - "loss_type": "discrete_time_cif", - "full_cov": false, - "checkpoint_path": "runs/delphi_fork_discrete_time_cif_sinusoidal_partcov_20260109-222502/best_model.pt" - }, - { - "name": "delphi_fork_exponential_mlp_fullcov", - "model_type": "delphi_fork", - "loss_type": "exponential", - "full_cov": true, - "checkpoint_path": "runs/delphi_fork_exponential_mlp_fullcov_20260110-042001/best_model.pt" - }, - { - "name": "delphi_fork_exponential_mlp_partcov", - "model_type": "delphi_fork", - "loss_type": "exponential", - "full_cov": false, - "checkpoint_path": "runs/delphi_fork_exponential_mlp_partcov_20260110-040737/best_model.pt" - }, - { - "name": "delphi_fork_exponential_sinusoidal_fullcov", - "model_type": "delphi_fork", - "loss_type": "exponential", - "full_cov": true, - "checkpoint_path": "runs/delphi_fork_exponential_sinusoidal_fullcov_20260109-222502/best_model.pt" - }, - { - "name": "delphi_fork_exponential_sinusoidal_partcov", - "model_type": "delphi_fork", - "loss_type": "exponential", - "full_cov": false, - "checkpoint_path": 
"runs/delphi_fork_exponential_sinusoidal_partcov_20260109-222502/best_model.pt" - }, - { - "name": "sap_delphi_discrete_time_cif_mlp_fullcov", - "model_type": "sap_delphi", - "loss_type": "discrete_time_cif", - "full_cov": true, - "checkpoint_path": "runs/sap_delphi_discrete_time_cif_mlp_fullcov_20260110-010514/best_model.pt" - }, - { - "name": "sap_delphi_discrete_time_cif_mlp_partcov", - "model_type": "sap_delphi", - "loss_type": "discrete_time_cif", - "full_cov": false, - "checkpoint_path": "runs/sap_delphi_discrete_time_cif_mlp_partcov_20260110-005804/best_model.pt" - }, - { - "name": "sap_delphi_discrete_time_cif_sinusoidal_fullcov", - "model_type": "sap_delphi", - "loss_type": "discrete_time_cif", - "full_cov": true, - "checkpoint_path": "runs/sap_delphi_discrete_time_cif_sinusoidal_fullcov_20260109-222502/best_model.pt" - }, - { - "name": "sap_delphi_discrete_time_cif_sinusoidal_partcov", - "model_type": "sap_delphi", - "loss_type": "discrete_time_cif", - "full_cov": false, - "checkpoint_path": "runs/sap_delphi_discrete_time_cif_sinusoidal_partcov_20260109-222501/best_model.pt" - }, - { - "name": "sap_delphi_exponential_mlp_fullcov", - "model_type": "sap_delphi", - "loss_type": "exponential", - "full_cov": true, - "checkpoint_path": "runs/sap_delphi_exponential_mlp_fullcov_20260110-042422/best_model.pt" - }, - { - "name": "sap_delphi_exponential_mlp_partcov", - "model_type": "sap_delphi", - "loss_type": "exponential", - "full_cov": false, - "checkpoint_path": "runs/sap_delphi_exponential_mlp_partcov_20260110-041850/best_model.pt" - }, - { - "name": "sap_delphi_exponential_sinusoidal_fullcov", - "model_type": "sap_delphi", - "loss_type": "exponential", - "full_cov": true, - "checkpoint_path": "runs/sap_delphi_exponential_sinusoidal_fullcov_20260109-222502/best_model.pt" - }, - { - "name": "sap_delphi_exponential_sinusoidal_partcov", - "model_type": "sap_delphi", - "loss_type": "exponential", - "full_cov": false, - "checkpoint_path": "runs/sap_delphi_exponential_sinusoidal_partcov_20260109-222502/best_model.pt" - } -] \ No newline at end of file