import argparse import csv import json import math import os import random import sys import time from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np import torch import torch.nn.functional as F from torch.utils.data import DataLoader, random_split from dataset import HealthDataset, health_collate_fn from model import DelphiFork, SapDelphi, SimpleHead # ============================================================ # Constants / defaults (aligned with evaluate_prompt.md) # ============================================================ DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")] DEFAULT_EVAL_HORIZONS = [0.25, 0.5, 1.0, 2.0, 5.0, 10.0] DAYS_PER_YEAR = 365.25 DEFAULT_DEATH_CAUSE_ID = 1256 # ============================================================ # Model specs # ============================================================ @dataclass(frozen=True) class ModelSpec: name: str model_type: str # delphi_fork | sap_delphi loss_type: str # exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif full_cov: bool checkpoint_path: str # ============================================================ # Determinism # ============================================================ def set_deterministic(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # ============================================================ # Utilities # ============================================================ def _parse_bool(x: Any) -> bool: if isinstance(x, bool): return x s = str(x).strip().lower() if s in {"true", "1", "yes", "y"}: return True if s in {"false", "0", "no", "n"}: return False raise ValueError(f"Cannot parse boolean: {x!r}") def load_models_json(path: str) -> List[ModelSpec]: with open(path, "r") as f: data = json.load(f) if not isinstance(data, list): raise ValueError("models_json must be a list of model entries") specs: List[ModelSpec] = [] for row in data: specs.append( ModelSpec( name=str(row["name"]), model_type=str(row["model_type"]), loss_type=str(row["loss_type"]), full_cov=_parse_bool(row["full_cov"]), checkpoint_path=str(row["checkpoint_path"]), ) ) return specs def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]: run_dir = os.path.dirname(os.path.abspath(checkpoint_path)) cfg_path = os.path.join(run_dir, "train_config.json") with open(cfg_path, "r") as f: cfg = json.load(f) return cfg def build_eval_subset( dataset: HealthDataset, train_ratio: float, val_ratio: float, seed: int, split: str, ): n_total = len(dataset) n_train = int(n_total * train_ratio) n_val = int(n_total * val_ratio) n_test = n_total - n_train - n_val train_ds, val_ds, test_ds = random_split( dataset, [n_train, n_val, n_test], generator=torch.Generator().manual_seed(seed), ) if split == "train": return train_ds if split == "val": return val_ds if split == "test": return test_ds if split == "all": return dataset raise ValueError("split must be one of: train, val, test, all") # ============================================================ # Context selection (anti-leakage) # ============================================================ def select_context_indices( event_seq: torch.Tensor, time_seq: torch.Tensor, offset_years: float, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Select per-sample 
prediction context index. IMPORTANT SEMANTICS: - The last observed token time is treated as the FOLLOW-UP END time. - We pick the last valid token with time <= (followup_end_time - offset). - We do NOT interpret followup_end_time as an event time. Returns: keep_mask: (B,) bool, which samples have a valid context t_ctx: (B,) long, index into sequence t_ctx_time: (B,) float, time (days) at context """ # valid tokens are event != 0 (padding is 0) valid = event_seq != 0 lengths = valid.sum(dim=1) last_idx = torch.clamp(lengths - 1, min=0) b = torch.arange(event_seq.size(0), device=event_seq.device) followup_end_time = time_seq[b, last_idx] t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR) eligible = valid & (time_seq <= t_cut.unsqueeze(1)) eligible_counts = eligible.sum(dim=1) keep = eligible_counts > 0 t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long) t_ctx_time = time_seq[b, t_ctx] return keep, t_ctx, t_ctx_time def next_event_after_context( event_seq: torch.Tensor, time_seq: torch.Tensor, t_ctx: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Return next disease event after context. Returns: dt_years: (B,) float, time to next disease in years; +inf if none cause: (B,) long, disease id in [0,K) for next event; -1 if none """ B, L = event_seq.shape b = torch.arange(B, device=event_seq.device) t0 = time_seq[b, t_ctx] # Allow same-day events while excluding the context token itself. # We rely on time-sorted sequences and select the FIRST valid future event by index. idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1) future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0) idx_min = torch.where( future, idxs, torch.full_like(idxs, L)).min(dim=1).values has = idx_min < L t_next = torch.where(has, idx_min, torch.zeros_like(idx_min)) t_next_time = time_seq[b, t_next] dt_days = t_next_time - t0 dt_years = dt_days / DAYS_PER_YEAR dt_years = torch.where( has, dt_years, torch.full_like(dt_years, float("inf"))) cause_token = event_seq[b, t_next] cause = (cause_token - 2).to(torch.long) cause = torch.where(has, cause, torch.full_like(cause, -1)) return dt_years, cause def multi_hot_ever_within_horizon( event_seq: torch.Tensor, time_seq: torch.Tensor, t_ctx: torch.Tensor, tau_years: float, n_disease: int, ) -> torch.Tensor: """Binary labels: disease k occurs within tau after context (any occurrence).""" B, L = event_seq.shape b = torch.arange(B, device=event_seq.device) t0 = time_seq[b, t_ctx] t1 = t0 + (tau_years * DAYS_PER_YEAR) idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1) # Include same-day events after context, exclude any token at/before context index. in_window = ( (idxs > t_ctx.unsqueeze(1)) & (time_seq >= t0.unsqueeze(1)) & (time_seq <= t1.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0) ) if not in_window.any(): return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) b_idx, t_idx = in_window.nonzero(as_tuple=True) disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) y[b_idx, disease_ids] = True return y def multi_hot_ever_after_context_anytime( event_seq: torch.Tensor, t_ctx: torch.Tensor, n_disease: int, ) -> torch.Tensor: """Binary labels: disease k occurs ANYTIME after the prediction context. This is Delphi2M-compatible for Task A case/control definition. Same-day events are included as long as they occur after the context token index. 
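    Illustrative example (0 = padding, values >= 2 are disease tokens whose
    disease id is token - 2): for event_seq[i] = [3, 5, 7, 7, 0] and
    t_ctx[i] = 1, the tokens after the context index are 7 and 7 (disease id 5),
    so y[i, 5] = True and every other entry of row i stays False.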
""" B, L = event_seq.shape idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1) future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0) y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device) if not future.any(): return y b_idx, t_idx = future.nonzero(as_tuple=True) disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) y[b_idx, disease_ids] = True return y def multi_hot_selected_causes_within_horizon( event_seq: torch.Tensor, time_seq: torch.Tensor, t_ctx: torch.Tensor, tau_years: float, cause_ids: torch.Tensor, n_disease: int, ) -> torch.Tensor: """Labels for selected causes only: does cause k occur within tau after context?""" B, L = event_seq.shape device = event_seq.device b = torch.arange(B, device=device) t0 = time_seq[b, t_ctx] t1 = t0 + (tau_years * DAYS_PER_YEAR) idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1) in_window = ( (idxs > t_ctx.unsqueeze(1)) & (time_seq >= t0.unsqueeze(1)) & (time_seq <= t1.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0) ) out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device) if not in_window.any(): return out b_idx, t_idx = in_window.nonzero(as_tuple=True) disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) # Filter to selected causes via a boolean membership mask over the global disease space. selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device) selected[cause_ids] = True keep = selected[disease_ids] if not keep.any(): return out b_idx = b_idx[keep] disease_ids = disease_ids[keep] # Map disease_id -> local index in cause_ids # Build a lookup table (global disease space) where lookup[disease_id] = local_index lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device) lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device) local = lookup[disease_ids] out[b_idx, local] = True return out # ============================================================ # CIF conversion # ============================================================ def cifs_from_exponential_logits( logits: torch.Tensor, taus: Sequence[float], eps: float = 1e-6, return_survival: bool = False, ) -> torch.Tensor: """Convert exponential cause-specific logits -> CIFs at taus. logits: (B, K) returns: (B, K, H) or (cif, survival) if return_survival """ hazards = F.softplus(logits) + eps total = hazards.sum(dim=1, keepdim=True) # (B,1) taus_t = torch.tensor(list(taus), device=logits.device, dtype=hazards.dtype).view(1, 1, -1) total_h = total.unsqueeze(-1) # (B,1,1) # (1 - exp(-Lambda * tau)) one_minus_surv = 1.0 - torch.exp(-total_h * taus_t) frac = hazards / torch.clamp(total, min=eps) cif = frac.unsqueeze(-1) * one_minus_surv # (B,K,H) # If total==0, set to 0 cif = torch.where(total_h > 0, cif, torch.zeros_like(cif)) if not return_survival: return cif survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H) # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors. nonzero = (total.squeeze(1) > 0).unsqueeze(1) survival = torch.where(nonzero, survival, torch.ones_like(survival)) return cif, survival def cifs_from_discrete_time_logits( logits: torch.Tensor, bin_edges: Sequence[float], taus: Sequence[float], return_survival: bool = False, ) -> torch.Tensor: """Convert discrete-time CIF logits -> CIFs at taus. 
    logits: (B, K+1, len(bin_edges)); softmax over dim=1 gives, per bin, a
        distribution over the K causes plus one "no event" (survival) row.
        The slots of the last dim corresponding to the 0 edge and the +inf
        edge are ignored.
    bin_edges: increasing edges starting at 0 and ending at +inf
    taus: must match finite bin edges (up to a small tolerance)
    returns: (B, K, H) or (cif, survival) if return_survival
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, len(bin_edges))")
    B, K_plus_1, n_bins_plus_1 = logits.shape
    K = K_plus_1 - 1
    edges = [float(x) for x in bin_edges]
    # drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)
    if n_bins_plus_1 != len(edges):
        raise ValueError("logits last dim must match len(bin_edges)")
    probs = torch.softmax(logits, dim=1)  # (B, K+1, len(bin_edges))
    # use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot)
    hazards = probs[:, :K, 1: 1 + n_bins]  # (B,K,n_bins)
    p_comp = probs[:, K, 1: 1 + n_bins]  # (B,n_bins)
    # survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v<u} p_comp[v]
    ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
    cum = torch.cumprod(p_comp, dim=1)
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # (B, n_bins)
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * hazards, dim=2)  # (B,K,n_bins)

    # Robust mapping from tau -> edge index (floating-point safe).
    # taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        j = int(np.argmin(diffs))
        if diffs[j] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
            )
        tau_to_idx.append(j)
    idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,K,H)
    if not return_survival:
        return cif
    # Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
    survival_bins = cum  # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
    survival = survival_bins.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival


def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
    z = torch.clamp(z, -12.0, 12.0)
    return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))


def cifs_from_lognormal_basis_binned_hazard_logits(
    logits: torch.Tensor,
    *,
    centers: Sequence[float],
    sigma: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    eps: float = 1e-8,
    alpha_floor: float = 0.0,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert Route-3 binned hazard logits -> CIFs at taus.

    logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
    taus are expected to align with finite bin edges.
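
    In symbols (matching the computation below, with Phi the standard normal
    CDF): for basis centers mu_r, shared bandwidth sigma and
    alpha = softplus(logits) + alpha_floor, the cumulative-hazard increment of
    cause j over the finite bin (t_{k-1}, t_k] is

        dH[j, k] = sum_r alpha[j, r] * (Phi((ln t_k - mu_r) / sigma)
                                        - Phi((ln t_{k-1} - mu_r) / sigma)),

    H[k] = sum_j dH[j, k], the bin is survived event-free with probability
    exp(-H[k]), and

        CIF_j(tau_m) = sum_{k <= m} S[k - 1] * (1 - exp(-H[k])) * dH[j, k] / H[k]

    with S[0] = 1 and S[k] = prod_{u <= k} exp(-H[u]).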
""" if logits.ndim not in {2, 3}: raise ValueError("logits must be 2D or 3D") if sigma.ndim != 0: raise ValueError("sigma must be a scalar tensor") device = logits.device dtype = logits.dtype centers_t = torch.tensor([float(x) for x in centers], device=device, dtype=dtype) r = int(centers_t.numel()) if r <= 0: raise ValueError("centers must be non-empty") offset = 0 if logits.ndim == 3: j = int(logits.shape[1]) if int(logits.shape[2]) != r: raise ValueError( f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}" ) else: d = int(logits.shape[1]) if d % r == 0: jr = d elif (d - 1) % r == 0: offset = 1 jr = d - 1 else: raise ValueError( f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}") j = jr // r if j <= 0: raise ValueError("Inferred J must be >= 1") edges = [float(x) for x in bin_edges] finite_edges = [e for e in edges[1:] if math.isfinite(e)] n_bins = len(finite_edges) if n_bins <= 0: raise ValueError("bin_edges must contain at least one finite edge") # Build finite bins [edges[k-1], edges[k]) for k=1..n_bins left = torch.tensor(edges[:n_bins], device=device, dtype=dtype) right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype) # Stable t_min clamp (aligns with training loss rule). t_min = 1e-12 if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0: t_min = edges[1] * 1e-6 t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype) left_is_zero = left <= 0 left_clamped = torch.clamp(left, min=t_min_t) log_left = torch.log(left_clamped) right_clamped = torch.clamp(right, min=t_min_t) log_right = torch.log(right_clamped) sigma_c = sigma.to(device=device, dtype=dtype) z_left = (log_left.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c z_right = (log_right.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c cdf_left = _normal_cdf_stable(z_left) if left_is_zero.any(): cdf_left = torch.where(left_is_zero.unsqueeze(-1), torch.zeros_like(cdf_left), cdf_left) cdf_right = _normal_cdf_stable(z_right) delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0) # (n_bins, R) if logits.ndim == 3: alpha = F.softplus(logits) + float(alpha_floor) # (B,J,R) else: logits_used = logits[:, offset:] alpha = (F.softplus(logits_used) + float(alpha_floor) ).view(logits.size(0), j, r) # (B,J,R) h_jk = torch.einsum("bjr,kr->bjk", alpha, delta_basis) # (B,J,n_bins) h_k = h_jk.sum(dim=1) # (B,n_bins) h_k = torch.clamp(h_k, min=eps) h_jk = torch.clamp(h_jk, min=eps) p_comp = torch.exp(-h_k) # (B,n_bins) one_minus = -torch.expm1(-h_k) # (B,n_bins) = 1-exp(-H) ratio = h_jk / torch.clamp(h_k.unsqueeze(1), min=eps) p_event = one_minus.unsqueeze(1) * ratio # (B,J,n_bins) ones = torch.ones((alpha.size(0), 1), device=device, dtype=dtype) cum = torch.cumprod(p_comp, dim=1) # survival after each bin s_prev = torch.cat([ones, cum[:, :-1]], dim=1) # survival before each bin cif_bins = torch.cumsum(s_prev.unsqueeze( 1) * p_event, dim=2) # (B,J,n_bins) finite_edges_arr = np.asarray(finite_edges, dtype=float) tau_to_idx: List[int] = [] for tau in taus: tau_f = float(tau) if not math.isfinite(tau_f): raise ValueError("taus must be finite for discrete-time CIF") diffs = np.abs(finite_edges_arr - tau_f) idx0 = int(np.argmin(diffs)) if diffs[idx0] > 1e-6: raise ValueError( f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})" ) tau_to_idx.append(idx0) idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long) cif = cif_bins.index_select(dim=2, index=idx) # (B,J,H) if not return_survival: return cif survival = cum.index_select(dim=1, index=idx) # (B,H) return 
cif, survival # ============================================================ # CIF integrity checks # ============================================================ def check_cif_integrity( cause_cif: np.ndarray, horizons: Sequence[float], *, tol: float = 1e-6, name: str = "", strict: bool = False, survival: Optional[np.ndarray] = None, ) -> Tuple[bool, List[str]]: """Run basic sanity checks on CIF arrays. Args: cause_cif: (N, K, H) horizons: length H tol: tolerance for inequalities name: model name for messages strict: if True, raise ValueError on first failure survival: optional (N, H) survival values at the same horizons Returns: (integrity_ok, notes) """ notes: List[str] = [] model_tag = f"[{name}] " if name else "" def _fail(msg: str) -> None: full = model_tag + msg if strict: raise ValueError(full) print("WARNING:", full) notes.append(msg) cif = np.asarray(cause_cif) if cif.ndim != 3: _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}") return False, notes N, K, H = cif.shape if H != len(horizons): _fail( f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})") # (5) Finite if not np.isfinite(cif).all(): _fail("integrity: non-finite values (NaN/Inf) in cause_cif") # (1) Range cmin = float(np.nanmin(cif)) cmax = float(np.nanmax(cif)) if cmin < -tol: _fail(f"integrity: range min={cmin} < -tol={-tol}") if cmax > 1.0 + tol: _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}") # (2) Monotonicity in horizons (per n,k) diffs = np.diff(cif, axis=2) if diffs.size > 0: if np.nanmin(diffs) < -tol: _fail("integrity: monotonicity violated (found negative diff along horizons)") # (3) Probability mass: sum_k CIF <= 1 mass = np.sum(cif, axis=1) # (N,H) mass_max = float(np.nanmax(mass)) if mass_max > 1.0 + tol: _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})") # (4) Conservation with survival, if provided if survival is None: warn = "integrity: survival not provided; skipping conservation check" if strict: # still skip (requested behavior), but keep message for context notes.append(warn) else: print("WARNING:", model_tag + warn) notes.append(warn) else: s = np.asarray(survival, dtype=float) if s.shape != (N, H): _fail( f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})") else: recon = 1.0 - s err = np.abs(recon - mass) # Discrete-time should be very tight; exponential may accumulate slightly more numerical error. tol_cons = max(float(tol), 1e-4) if float(np.nanmax(err)) > tol_cons: _fail( f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})") ok = len([n for n in notes if not n.endswith( "skipping conservation check")]) == 0 return ok, notes # ============================================================ # Metrics # ============================================================ # --- Rank-based ROC AUC (ties handled via midranks) --- def compute_midrank(x: np.ndarray) -> np.ndarray: """Vectorized midrank computation (ties -> average ranks).""" x = np.asarray(x, dtype=float) n = int(x.shape[0]) if n == 0: return np.asarray([], dtype=float) order = np.argsort(x, kind="mergesort") z = x[order] # Find tie groups in sorted order. diff = np.diff(z) # boundaries includes 0 and n boundaries = np.concatenate( [np.array([0], dtype=int), np.nonzero(diff != 0) [0] + 1, np.array([n], dtype=int)] ) starts = boundaries[:-1] ends = boundaries[1:] lens = ends - starts # Midrank for each group in 1-based rank space. 
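    # Derivation (illustrative): a tie group occupying 0-based sorted positions
    # starts..ends-1 spans 1-based ranks starts+1..ends, whose average is
    # (starts + 1 + ends) / 2 = 0.5 * (starts + ends - 1) + 1.0.
    # E.g. x = [3, 1, 3]: sorted z = [1, 3, 3]; the tied 3s have starts=1,
    # ends=3, so their midrank is 2.5 and the output is [2.5, 1.0, 2.5].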
mids = 0.5 * (starts + ends - 1) + 1.0 t_sorted = np.repeat(mids, lens).astype(float, copy=False) out = np.empty(n, dtype=float) out[order] = t_sorted return out def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float: """Rank-based ROC AUC via Mann–Whitney U statistic (ties handled by midranks). Returns NaN for degenerate labels. """ y = np.asarray(y_true, dtype=int) s = np.asarray(y_score, dtype=float) if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]: raise ValueError("roc_auc_rank expects 1D arrays of equal length") m = int(np.sum(y == 1)) n = int(np.sum(y == 0)) if m == 0 or n == 0: return float("nan") ranks = compute_midrank(s) sum_pos = float(np.sum(ranks[y == 1])) auc = (sum_pos - m * (m + 1) / 2.0) / (m * n) return float(auc) def brier_score(p: np.ndarray, y: np.ndarray) -> float: p = np.asarray(p, dtype=float) y = np.asarray(y, dtype=float) return float(np.mean((p - y) ** 2)) def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]: bins, ici = _calibration_bins_and_ici( p, y, n_bins=int(n_bins), return_bins=True) return {"bins": bins, "ici": float(ici)} def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float: """Fast ICI only (no per-bin point export).""" _, ici = _calibration_bins_and_ici( p, y, n_bins=int(n_bins), return_bins=False) return float(ici) def _calibration_bins_and_ici( p: np.ndarray, y: np.ndarray, *, n_bins: int, return_bins: bool, ) -> Tuple[List[Dict[str, Any]], float]: """Vectorized quantile binning for calibration + ICI.""" p = np.asarray(p, dtype=float) y = np.asarray(y, dtype=float) if p.size == 0: return ([], float("nan")) if return_bins else ([], float("nan")) q = np.linspace(0.0, 1.0, int(n_bins) + 1) edges = np.quantile(p, q) edges = np.asarray(edges, dtype=float) edges[0] = -np.inf edges[-1] = np.inf # Bin assignment: i if edges[i] < p <= edges[i+1] bin_idx = np.searchsorted(edges, p, side="right") - 1 bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1) counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float) sum_p = np.bincount(bin_idx, weights=p, minlength=int(n_bins)).astype(float) sum_y = np.bincount(bin_idx, weights=y, minlength=int(n_bins)).astype(float) nonempty = counts > 0 if not np.any(nonempty): return ([], float("nan")) if return_bins else ([], float("nan")) p_mean = np.zeros(int(n_bins), dtype=float) y_mean = np.zeros(int(n_bins), dtype=float) p_mean[nonempty] = sum_p[nonempty] / counts[nonempty] y_mean[nonempty] = sum_y[nonempty] / counts[nonempty] diffs = np.abs(p_mean[nonempty] - y_mean[nonempty]) ici = float(np.mean(diffs)) if diffs.size else float("nan") if not return_bins: return [], ici bins: List[Dict[str, Any]] = [] idxs = np.nonzero(nonempty)[0] for i in idxs.tolist(): bins.append( { "bin": int(i), "p_mean": float(p_mean[i]), "y_mean": float(y_mean[i]), "n": int(counts[i]), } ) return bins, ici def _progress_line(done: int, total: int, prefix: str = "") -> str: total_i = max(1, int(total)) done_i = max(0, min(int(done), total_i)) width = 28 frac = done_i / total_i filled = int(round(width * frac)) bar = "#" * filled + "-" * (width - filled) pct = 100.0 * frac return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)" def _should_show_progress(mode: str) -> bool: m = str(mode).strip().lower() if m in {"0", "false", "no", "none", "off"}: return False # Default: show if interactive. 
if m in {"auto", "1", "true", "yes", "on", "bar"}: try: return bool(sys.stdout.isatty()) except Exception: return True return True def _safe_float(x: Any, default: float = float("nan")) -> float: try: return float(x) except Exception: return float(default) def _ensure_dir(path: str) -> None: os.makedirs(path, exist_ok=True) def load_cause_names(path: str = "labels.csv") -> Dict[int, str]: """Load 0-based cause_id -> name mapping. labels.csv is assumed to be one label per line, in disease-id order. """ if not os.path.exists(path): return {} mapping: Dict[int, str] = {} with open(path, "r", encoding="utf-8") as f: for i, line in enumerate(f): name = line.strip() if name: mapping[int(i)] = name return mapping def pick_focus_causes( *, counts_within_tau: Optional[np.ndarray], n_disease: int, death_cause_id: int = DEFAULT_DEATH_CAUSE_ID, k: int = 5, ) -> List[int]: """Pick focus causes for user-facing evaluation. Rule: 1) Always include death_cause_id first. 2) Then add K additional causes by descending event count if available. If counts_within_tau is None, fall back to descending cause_id coverage proxy. Notes: - counts_within_tau is expected to be shape (n_disease,). - Deterministic: ties broken by smaller cause id. """ n_disease_i = int(n_disease) if death_cause_id < 0 or death_cause_id >= n_disease_i: print( f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); " "it will be omitted from focus causes." ) focus: List[int] = [] else: focus = [int(death_cause_id)] candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)] if counts_within_tau is not None: c = np.asarray(counts_within_tau).astype(float) if c.shape[0] != n_disease_i: print( "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering." ) counts_within_tau = None else: # Sort by (-count, cause_id) order = sorted(candidates, key=lambda i: (-float(c[i]), int(i))) order = [i for i in order if float(c[i]) > 0] focus.extend([int(i) for i in order[: int(k)]]) if counts_within_tau is None: # Fallback: deterministic coverage proxy (descending id, excluding death), then take K. # (Real coverage requires data; this path is mostly for robustness.) order = sorted(candidates, key=lambda i: (-int(i))) focus.extend([int(i) for i in order[: int(k)]]) # De-dup while preserving order seen = set() out: List[int] = [] for cid in focus: if cid not in seen: out.append(cid) seen.add(cid) return out def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None: _ensure_dir(os.path.dirname(os.path.abspath(path)) or ".") with open(path, "w", newline="", encoding="utf-8") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() for r in rows: w.writerow(r) def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]: """Return list of (sex_label, mask) slices including an 'all' slice. If sex is missing, returns only ('all', None). 
""" out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)] if sex is None: return out s = np.asarray(sex) if s.ndim != 1: return out for val in [0, 1]: m = (s == val) if int(np.sum(m)) > 0: out.append((str(val), m)) return out def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray: edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1)) edges = np.asarray(edges, dtype=float) edges[0] = -np.inf edges[-1] = np.inf return edges def compute_risk_stratification_bins( p: np.ndarray, y: np.ndarray, *, q_default: int = 10, ) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]: """Compute quantile-based risk strata and a compact summary.""" p = np.asarray(p, dtype=float) y = np.asarray(y, dtype=float) n = int(p.shape[0]) if n == 0: return 0, [], { "y_overall": float("nan"), "top_decile_y_rate": float("nan"), "bottom_half_y_rate": float("nan"), "lift_top10_vs_bottom50": float("nan"), "slope_pred_vs_obs": float("nan"), } # Choose quantiles robustly. q = int(q_default) if n < 200: q = 5 edges = _quantile_edges(p, q) y_overall = float(np.mean(y)) bin_rows: List[Dict[str, Any]] = [] p_means: List[float] = [] y_rates: List[float] = [] n_bins: List[int] = [] for i in range(q): mask = (p > edges[i]) & (p <= edges[i + 1]) nb = int(np.sum(mask)) if nb == 0: # Keep the row for consistent plotting; set NaNs. bin_rows.append( { "q": int(i + 1), "n_bin": 0, "p_mean": float("nan"), "y_rate": float("nan"), "y_overall": y_overall, "lift_vs_overall": float("nan"), } ) continue p_mean = float(np.mean(p[mask])) y_rate = float(np.mean(y[mask])) lift = float(y_rate / y_overall) if y_overall > 0 else float("nan") bin_rows.append( { "q": int(i + 1), "n_bin": nb, "p_mean": p_mean, "y_rate": y_rate, "y_overall": y_overall, "lift_vs_overall": lift, } ) p_means.append(p_mean) y_rates.append(y_rate) n_bins.append(nb) # Summary top_mask = (p > edges[q - 1]) & (p <= edges[q]) bot_half_mask = (p > edges[0]) & (p <= edges[q // 2]) top_y = float(np.mean(y[top_mask])) if int( np.sum(top_mask)) > 0 else float("nan") bot_y = float(np.mean(y[bot_half_mask])) if int( np.sum(bot_half_mask)) > 0 else float("nan") lift_top_vs_bottom = float(top_y / bot_y) if (np.isfinite(top_y) and np.isfinite(bot_y) and bot_y > 0) else float("nan") slope = float("nan") if len(p_means) >= 2: # Weighted least squares slope of y_rate ~ p_mean. 
x = np.asarray(p_means, dtype=float) yy = np.asarray(y_rates, dtype=float) w = np.asarray(n_bins, dtype=float) xm = float(np.average(x, weights=w)) ym = float(np.average(yy, weights=w)) denom = float(np.sum(w * (x - xm) ** 2)) if denom > 0: slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom) summary = { "y_overall": y_overall, "top_decile_y_rate": top_y, "bottom_half_y_rate": bot_y, "lift_top10_vs_bottom50": lift_top_vs_bottom, "slope_pred_vs_obs": slope, } return q, bin_rows, summary def compute_capture_points( p: np.ndarray, y: np.ndarray, k_pcts: Sequence[int], ) -> List[Dict[str, Any]]: p = np.asarray(p, dtype=float) y = np.asarray(y, dtype=float) n = int(p.shape[0]) if n == 0: return [] order = np.argsort(-p) y_sorted = y[order] events_total = float(np.sum(y_sorted)) rows: List[Dict[str, Any]] = [] for k in k_pcts: kf = float(k) n_targeted = int(math.ceil(n * kf / 100.0)) n_targeted = max(1, min(n_targeted, n)) events_targeted = float(np.sum(y_sorted[:n_targeted])) capture = float(events_targeted / events_total) if events_total > 0 else float("nan") precision = float(events_targeted / float(n_targeted)) rows.append( { "k_pct": int(k), "n_targeted": int(n_targeted), "events_targeted": float(events_targeted), "events_total": float(events_total), "event_capture_rate": capture, "precision_in_targeted": precision, } ) return rows def compute_event_rate_at_topk_causes( p_tau: np.ndarray, y_tau: np.ndarray, topk_list: Sequence[int], ) -> List[Dict[str, Any]]: """Compute Event Rate@K for cross-cause prioritization. For each individual, rank causes by predicted risk p_tau at a fixed horizon. For each K, select top-K causes and compute the fraction that occur within the horizon. Args: p_tau: (N, K) predicted CIFs at a fixed horizon y_tau: (N, K) binary labels (0/1) whether cause occurs within the horizon topk_list: list of K values to evaluate Returns: List of rows with: - topk - event_rate_mean / event_rate_median - recall_mean / recall_median (averaged over individuals with >=1 true cause) - n_total / n_valid_recall """ p = np.asarray(p_tau, dtype=float) y = np.asarray(y_tau, dtype=float) if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape: raise ValueError( "compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape") n, k_total = p.shape if n == 0 or k_total == 0: out: List[Dict[str, Any]] = [] for kk in topk_list: out.append( { "topk": int(max(1, int(kk))), "event_rate_mean": float("nan"), "event_rate_median": float("nan"), "recall_mean": float("nan"), "recall_median": float("nan"), "n_total": int(n), "n_valid_recall": 0, } ) return out # Sanitize K list. topks = sorted({int(x) for x in topk_list if int(x) > 0}) if not topks: return [] max_k = min(int(max(topks)), int(k_total)) if max_k <= 0: return [] # Efficient: get top max_k causes per individual, then sort within those. 
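    # Sketch of the selection pattern used below: np.argpartition(-p, kth=max_k-1)
    # places the max_k largest scores of each row in the first max_k columns
    # (in arbitrary order); a second argsort over just those columns then yields
    # the top-max_k cause indices in descending-score order. E.g. for one row
    # p = [0.1, 0.9, 0.3, 0.7] and max_k = 2, the partition step isolates
    # columns {1, 3} and the sort step orders them as [1, 3].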
part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k] # (N, max_k) p_part = np.take_along_axis(p, part, axis=1) order = np.argsort(-p_part, axis=1) top_sorted = np.take_along_axis(part, order, axis=1) # (N, max_k) out_rows: List[Dict[str, Any]] = [] for kk in topks: kk_eff = min(int(kk), int(k_total)) idx = top_sorted[:, :kk_eff] y_sel = np.take_along_axis(y, idx, axis=1) # Selected true causes per person hit = np.sum(y_sel, axis=1) # Precision-like: fraction of selected causes that occur per_person = hit / \ float(kk_eff) if kk_eff > 0 else np.full((n,), np.nan) # Recall@K: fraction of true causes covered by top-K (undefined when no true cause) g = np.sum(y, axis=1) valid = g > 0 recall = np.full((n,), np.nan, dtype=float) recall[valid] = hit[valid] / g[valid] out_rows.append( { "topk": int(kk_eff), "event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"), "event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"), "recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"), "recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"), "n_total": int(n), "n_valid_recall": int(np.sum(valid)), } ) return out_rows def compute_random_ranking_baseline_topk( y_tau: np.ndarray, topk_list: Sequence[int], *, z: float = 1.645, ) -> List[Dict[str, Any]]: """Random ranking baseline for Event Rate@K and Recall@K. Baseline definition: - For each individual, pick K causes uniformly at random without replacement. - EventRate@K = (# selected causes that occur) / K. - Recall@K = (# selected causes that occur) / (# causes that occur), averaged over individuals with >=1 true cause. This function computes the expected baseline mean and an approximate 5-95% range for the population mean using a normal approximation of the hypergeometric variance. Args: y_tau: (N, K_total) binary labels topk_list: K values z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%) Returns: Rows with baseline means and p05/p95 for both metrics. """ y = np.asarray(y_tau, dtype=float) if y.ndim != 2: raise ValueError( "compute_random_ranking_baseline_topk expects y_tau with shape (N,K)") n, k_total = y.shape topks = sorted({int(x) for x in topk_list if int(x) > 0}) if not topks: return [] g = np.sum(y, axis=1) # (N,) valid = g > 0 n_valid = int(np.sum(valid)) out: List[Dict[str, Any]] = [] for kk in topks: kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk) if n == 0 or k_total == 0 or kk_eff <= 0: out.append( { "topk": int(max(1, kk_eff)), "baseline_event_rate_mean": float("nan"), "baseline_event_rate_p05": float("nan"), "baseline_event_rate_p95": float("nan"), "baseline_recall_mean": float("nan"), "baseline_recall_p05": float("nan"), "baseline_recall_p95": float("nan"), "n_total": int(n), "n_valid_recall": int(n_valid), "k_total": int(k_total), "baseline_method": "random_ranking_hypergeometric_normal_approx", } ) continue # Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total. er_mean = float(np.mean(g / float(k_total))) # Variance of hypergeometric count X: # Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p=g/K_total. 
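        # Worked example (illustrative): K_total = 10 causes, g = 2 true causes,
        # K = 5 picks -> p = 0.2, finite-population correction (10-5)/(10-1) = 5/9,
        # Var(X) = 5 * 0.2 * 0.8 * 5/9 ~= 0.444, and Var(EventRate@K) = Var(X)/K^2
        # ~= 0.0178. Per-person variances are then combined as sum(var)/n^2 to get
        # the standard error of the population mean used for the p05/p95 range.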
if k_total > 1 and kk_eff < k_total: p = g / float(k_total) finite_corr = (float(k_total - kk_eff) / float(k_total - 1)) var_x = float(kk_eff) * p * (1.0 - p) * finite_corr else: var_x = np.zeros_like(g, dtype=float) var_er = var_x / (float(kk_eff) ** 2) se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n)) er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0)) er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0)) # Expected Recall@K for individuals with g>0 is K/K_total (clipped). rec_mean = float(min(float(kk_eff) / float(k_total), 1.0)) if n_valid > 0: var_rec = np.zeros_like(g, dtype=float) gv = g[valid] var_xv = var_x[valid] # Var( X / g ) = Var(X) / g^2 (approx; g is fixed per individual) var_rec_v = var_xv / (gv ** 2) se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid) rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0)) rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0)) else: rec_p05 = float("nan") rec_p95 = float("nan") out.append( { "topk": int(kk_eff), "baseline_event_rate_mean": er_mean, "baseline_event_rate_p05": er_p05, "baseline_event_rate_p95": er_p95, "baseline_recall_mean": rec_mean, "baseline_recall_p05": float(rec_p05), "baseline_recall_p95": float(rec_p95), "n_total": int(n), "n_valid_recall": int(n_valid), "k_total": int(k_total), "baseline_method": "random_ranking_hypergeometric_normal_approx", } ) return out def count_occurs_within_horizon( loader: DataLoader, offset_years: float, tau_years: float, n_disease: int, device: str, ) -> Tuple[np.ndarray, int]: """Count per-person occurrence within tau after the prediction context. Returns counts[k] = number of individuals with disease k at least once in (t_ctx, t_ctx+tau]. """ counts = torch.zeros((n_disease,), dtype=torch.long, device=device) n_total_eval = 0 for batch in loader: event_seq, time_seq, cont_feats, cate_feats, sexes = batch event_seq = event_seq.to(device) time_seq = time_seq.to(device) keep, t_ctx, _ = select_context_indices( event_seq, time_seq, offset_years) if not keep.any(): continue n_total_eval += int(keep.sum().item()) event_seq = event_seq[keep] time_seq = time_seq[keep] t_ctx = t_ctx[keep] B, L = event_seq.shape b = torch.arange(B, device=device) t0 = time_seq[b, t_ctx] t1 = t0 + (float(tau_years) * DAYS_PER_YEAR) idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1) in_window = ( (idxs > t_ctx.unsqueeze(1)) & (time_seq >= t0.unsqueeze(1)) & (time_seq <= t1.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0) ) if not in_window.any(): continue b_idx, t_idx = in_window.nonzero(as_tuple=True) disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long) # unique per (person, disease) to count per-person within-window occurrence key = b_idx.to(torch.long) * int(n_disease) + disease_ids uniq = torch.unique(key) uniq_disease = uniq % int(n_disease) counts.scatter_add_(0, uniq_disease, torch.ones_like( uniq_disease, dtype=torch.long)) return counts.detach().cpu().numpy(), int(n_total_eval) # ============================================================ # Evaluation core # ============================================================ def instantiate_model_and_head( cfg: Dict[str, Any], dataset: HealthDataset, device: str, checkpoint_path: str = "", ) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]: model_type = str(cfg["model_type"]) loss_type = str(cfg["loss_type"]) loss_params: Dict[str, Any] = {} if loss_type == "exponential": out_dims = [dataset.n_disease] elif loss_type == "discrete_time_cif": bin_edges = 
cfg.get("bin_edges", DEFAULT_BIN_EDGES) out_dims = [dataset.n_disease + 1, len(bin_edges)] elif loss_type == "lognormal_basis_binned_hazard_cif": centers = cfg.get("lognormal_centers", None) if centers is None: centers = cfg.get("centers", None) if not isinstance(centers, list) or len(centers) == 0: raise ValueError( "lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json" ) r = len(centers) desired_total = int(dataset.n_disease) * int(r) legacy_total = 1 + desired_total # Prefer the new shape (K,R) but keep compatibility with older checkpoints # that used a single flattened dimension (1 + K*R). out_dims = [int(dataset.n_disease), int(r)] if checkpoint_path: try: ckpt = torch.load(checkpoint_path, map_location="cpu") head_sd = ckpt.get("head_state_dict", {}) w = head_sd.get("net.2.weight", None) if isinstance(w, torch.Tensor) and w.ndim == 2: out_features = int(w.shape[0]) if out_features == legacy_total: out_dims = [legacy_total] elif out_features == desired_total: out_dims = [int(dataset.n_disease), int(r)] else: raise ValueError( f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)" ) except Exception as e: raise ValueError( f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}" ) loss_params["centers"] = centers loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3)) loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0)) loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7)) loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8)) loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0)) else: raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}") if model_type == "delphi_fork": backbone = DelphiFork( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=int(cfg["n_embd"]), n_head=int(cfg["n_head"]), n_layer=int(cfg["n_layer"]), pdrop=float(cfg.get("pdrop", 0.0)), age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, ).to(device) elif model_type == "sap_delphi": # Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path. emb_path = cfg.get("pretrained_emb_path", None) if emb_path in {"", None}: emb_path = cfg.get("pretrained_emd_path", None) if emb_path in {"", None}: run_dir = os.path.dirname(os.path.abspath( checkpoint_path)) if checkpoint_path else "" print( f"WARNING: SapDelphi pretrained embedding path missing in config " f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). 
" f"checkpoint={checkpoint_path} run_dir={run_dir}" ) backbone = SapDelphi( n_disease=dataset.n_disease, n_tech_tokens=2, n_embd=int(cfg["n_embd"]), n_head=int(cfg["n_head"]), n_layer=int(cfg["n_layer"]), pdrop=float(cfg.get("pdrop", 0.0)), age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), n_cont=dataset.n_cont, n_cate=dataset.n_cate, cate_dims=dataset.cate_dims, pretrained_weights_path=emb_path, freeze_embeddings=True, ).to(device) else: raise ValueError(f"Unsupported model_type: {model_type}") head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device) bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES) return backbone, head, loss_type, bin_edges, loss_params @torch.no_grad() def predict_cifs_for_model( backbone: torch.nn.Module, head: torch.nn.Module, loss_type: str, bin_edges: Sequence[float], loss_params: Dict[str, Any], loader: DataLoader, device: str, offset_years: float, eval_horizons: Sequence[float], n_disease: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Run model and produce cause-specific, time-dependent CIF outputs. Returns: cif_full: (N, K, H) survival: (N, H) y_cause_within_tau: (N, K, H) NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk). """ backbone.eval() head.eval() # We will accumulate in CPU lists, then concat. cif_full_list: List[np.ndarray] = [] survival_list: List[np.ndarray] = [] y_cause_within_list: List[np.ndarray] = [] sex_list: List[np.ndarray] = [] for batch in loader: event_seq, time_seq, cont_feats, cate_feats, sexes = batch event_seq = event_seq.to(device) time_seq = time_seq.to(device) cont_feats = cont_feats.to(device) cate_feats = cate_feats.to(device) sexes = sexes.to(device) keep, t_ctx, _ = select_context_indices( event_seq, time_seq, offset_years) if not keep.any(): continue # filter batch event_seq = event_seq[keep] time_seq = time_seq[keep] cont_feats = cont_feats[keep] cate_feats = cate_feats[keep] sexes_k = sexes[keep] t_ctx = t_ctx[keep] h = backbone(event_seq, time_seq, sexes_k, cont_feats, cate_feats) # (B,L,D) b = torch.arange(h.size(0), device=device) c = h[b, t_ctx] # (B,D) logits = head(c) if loss_type == "exponential": cif_full, survival = cifs_from_exponential_logits( logits, eval_horizons, return_survival=True) # (B,K,H), (B,H) elif loss_type == "discrete_time_cif": cif_full, survival = cifs_from_discrete_time_logits( # (B,K,H), (B,H) logits, bin_edges, eval_horizons, return_survival=True) elif loss_type == "lognormal_basis_binned_hazard_cif": centers = loss_params.get("centers", None) sigma = loss_params.get("sigma", None) if centers is None or sigma is None: raise ValueError( "lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']") cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits( logits, centers=centers, sigma=sigma, bin_edges=bin_edges, taus=eval_horizons, eps=float(loss_params.get("loss_eps", 1e-8)), alpha_floor=float(loss_params.get("alpha_floor", 0.0)), return_survival=True, ) else: raise ValueError(f"Unsupported loss_type: {loss_type}") # Within-horizon labels for all causes: disease k occurs within tau after context. 
y_within_full = torch.stack( [ multi_hot_ever_within_horizon( event_seq=event_seq, time_seq=time_seq, t_ctx=t_ctx, tau_years=float(tau), n_disease=int(n_disease), ).to(torch.float32) for tau in eval_horizons ], dim=2, ) # (B,K,H) cif_full_list.append(cif_full.detach().cpu().numpy()) survival_list.append(survival.detach().cpu().numpy()) y_cause_within_list.append(y_within_full.detach().cpu().numpy()) sex_list.append(sexes_k.detach().cpu().numpy()) if not cif_full_list: raise RuntimeError( "No valid samples for evaluation (all batches filtered out by offset).") cif_full = np.concatenate(cif_full_list, axis=0) survival = np.concatenate(survival_list, axis=0) y_cause_within = np.concatenate(y_cause_within_list, axis=0) sex = np.concatenate( sex_list, axis=0) if sex_list else np.array([], dtype=int) return cif_full, survival, y_cause_within, sex def evaluate_one_model( model_name: str, cif_full: np.ndarray, y_cause_within_tau: np.ndarray, eval_horizons: Sequence[float], out_rows: List[Dict[str, Any]], calib_rows: List[Dict[str, Any]], calib_cause_ids: Optional[Sequence[int]], n_calib_bins: int = 10, metric_workers: int = 0, progress: str = "auto", ) -> None: """Compute per-cause metrics for ALL diseases. Notes: - Writes scalar metrics for all causes into out_rows. - Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable). """ cif_full = np.asarray(cif_full, dtype=float) y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float) if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3: raise ValueError( "Expected cif_full and y_cause_within_tau with shape (N, K, H)") if cif_full.shape != y_cause_within_tau.shape: raise ValueError( f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}" ) N, K, H = cif_full.shape if H != len(eval_horizons): raise ValueError("H mismatch between cif_full and eval_horizons") calib_set = set(int(x) for x in calib_cause_ids) if calib_cause_ids is not None else set() workers = int(metric_workers) if workers <= 0: workers = int(min(8, os.cpu_count() or 1)) workers = max(1, workers) show_progress = _should_show_progress(progress) def _eval_chunk( *, tau: float, p_tau: np.ndarray, y_tau: np.ndarray, brier_by_cause: np.ndarray, cause_ids: np.ndarray, ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]: local_rows: List[Dict[str, Any]] = [] local_calib: List[Dict[str, Any]] = [] for cid in cause_ids.tolist(): p = p_tau[:, cid] y = y_tau[:, cid] local_rows.append( { "model_name": model_name, "metric_name": "cause_brier", "horizon": float(tau), "cause": int(cid), "value": float(brier_by_cause[cid]), "ci_low": "", "ci_high": "", } ) # ICI: compute bins only if we will export them. need_bins = (not calib_set) or (int(cid) in calib_set) if need_bins: cal = calibration_deciles(p, y, n_bins=n_calib_bins) ici = float(cal["ici"]) else: cal = None ici = calibration_ici_only(p, y, n_bins=n_calib_bins) local_rows.append( { "model_name": model_name, "metric_name": "cause_ici", "horizon": float(tau), "cause": int(cid), "value": float(ici), "ci_low": "", "ci_high": "", } ) # Secondary: discrimination via AUC at the same horizon (point estimate only). 
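            # roc_auc_rank is the Mann-Whitney estimate computed from midranks, so
            # ties in predicted CIFs are handled and the result is NaN whenever a
            # cause has no cases or no controls at this horizon;
            # summarize_over_diseases later skips those NaNs.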
auc = roc_auc_rank(y, p) local_rows.append( { "model_name": model_name, "metric_name": "cause_auc", "horizon": float(tau), "cause": int(cid), "value": float(auc), "ci_low": "", "ci_high": "", } ) if need_bins and cal is not None: for binfo in cal.get("bins", []): local_calib.append( { "model_name": model_name, "task": "cause_k", "horizon": float(tau), "cause_id": int(cid), "bin_index": int(binfo["bin"]), "p_mean": float(binfo["p_mean"]), "y_mean": float(binfo["y_mean"]), "n_in_bin": int(binfo["n"]), } ) return local_rows, local_calib, int(cause_ids.size) # Cause-specific, time-dependent metrics per horizon. for h_i, tau in enumerate(eval_horizons): p_tau = cif_full[:, :, h_i] # (N, K) y_tau = y_cause_within_tau[:, :, h_i] # (N, K) # Vectorized Brier for speed. brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0) # (K,) # Parallelize disease-level metrics; chunk to avoid millions of futures. all_ids = np.arange(int(K), dtype=int) chunks = np.array_split(all_ids, workers) done = 0 prefix = f"[{model_name}] tau={float(tau)}y " t0 = time.time() if workers <= 1: for ch in chunks: r_chunk, c_chunk, n_done = _eval_chunk( tau=float(tau), p_tau=p_tau, y_tau=y_tau, brier_by_cause=brier_by_cause, cause_ids=ch, ) out_rows.extend(r_chunk) calib_rows.extend(c_chunk) done += int(n_done) if show_progress: sys.stdout.write( "\r" + _progress_line(done, int(K), prefix=prefix)) sys.stdout.flush() else: with ThreadPoolExecutor(max_workers=workers) as ex: futs = [ ex.submit( _eval_chunk, tau=float(tau), p_tau=p_tau, y_tau=y_tau, brier_by_cause=brier_by_cause, cause_ids=ch, ) for ch in chunks if int(ch.size) > 0 ] for fut in as_completed(futs): r_chunk, c_chunk, n_done = fut.result() out_rows.extend(r_chunk) calib_rows.extend(c_chunk) done += int(n_done) if show_progress: sys.stdout.write( "\r" + _progress_line(done, int(K), prefix=prefix)) sys.stdout.flush() if show_progress: dt = time.time() - t0 sys.stdout.write("\r" + _progress_line(int(K), int(K), prefix=prefix) + f" ({dt:.1f}s)\n") sys.stdout.flush() def summarize_over_diseases( rows: List[Dict[str, Any]], *, model_name: str, eval_horizons: Sequence[float], metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"), ) -> List[Dict[str, Any]]: """Summarize mean/median of each metric over diseases (per horizon).""" out: List[Dict[str, Any]] = [] # Build metric_name -> horizon -> list of values bucket: Dict[Tuple[str, float], List[float]] = {} for r in rows: if r.get("model_name") != model_name: continue m = str(r.get("metric_name")) if m not in set(metrics): continue h = _safe_float(r.get("horizon")) v = _safe_float(r.get("value")) if not np.isfinite(h): continue if not np.isfinite(v): continue bucket.setdefault((m, float(h)), []).append(float(v)) for tau in eval_horizons: ht = float(tau) for m in metrics: vals = bucket.get((str(m), ht), []) if vals: arr = np.asarray(vals, dtype=float) mean_v = float(np.mean(arr)) med_v = float(np.median(arr)) n_valid = int(arr.size) else: mean_v = float("nan") med_v = float("nan") n_valid = 0 out.append( { "model_name": str(model_name), "metric_name": str(m), "horizon": ht, "mean": mean_v, "median": med_v, "n_valid": n_valid, } ) return out def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None: fieldnames = [ "model_name", "task", "horizon", "cause_id", "bin_index", "p_mean", "y_mean", "n_in_bin", ] with open(path, "w", newline="") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() for r in rows: w.writerow(r) def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> 
None: fieldnames = [ "model_name", "metric_name", "horizon", "cause", "value", "ci_low", "ci_high", ] with open(path, "w", newline="") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() for r in rows: w.writerow(r) def _make_eval_tag(split: str, offset_years: float) -> str: """Short tag for filenames written into run directories.""" off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".") return f"{split}_offset{off}y" def main() -> int: ap = argparse.ArgumentParser( description="Unified downstream evaluation via CIFs") ap.add_argument("--models_json", type=str, required=True, help="Path to models list JSON") ap.add_argument("--data_prefix", type=str, default="ukb", help="Dataset prefix") ap.add_argument("--split", type=str, default="test", choices=["train", "val", "test", "all"], help="Which split to evaluate") ap.add_argument("--offset_years", type=float, default=0.5, help="Anti-leakage offset (years)") ap.add_argument("--eval_horizons", type=float, nargs="*", default=DEFAULT_EVAL_HORIZONS) ap.add_argument("--batch_size", type=int, default=128) ap.add_argument("--num_workers", type=int, default=0) ap.add_argument("--seed", type=int, default=123) ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") ap.add_argument("--out_csv", type=str, default="eval_summary.csv") ap.add_argument("--out_meta_json", type=str, default="eval_meta.json") # Integrity checks ap.add_argument("--integrity_strict", action="store_true", default=False) ap.add_argument("--integrity_tol", type=float, default=1e-6) # Speed/UX ap.add_argument( "--metric_workers", type=int, default=0, help="Threads for per-disease metrics (0=auto, 1=disable parallelism)", ) ap.add_argument( "--progress", type=str, default="auto", choices=["auto", "bar", "none"], help="Progress visualization during per-disease evaluation", ) # Export settings for user-facing experiments ap.add_argument("--export_dir", type=str, default="eval_exports") ap.add_argument("--death_cause_id", type=int, default=DEFAULT_DEATH_CAUSE_ID) ap.add_argument("--focus_k", type=int, default=5, help="Additional non-death causes to include") ap.add_argument("--capture_k_pcts", type=int, nargs="*", default=[1, 5, 10, 20]) ap.add_argument( "--capture_curve_max_pct", type=int, default=50, help="If >0, also export a dense capture curve for k=1..max_pct", ) # High-risk cause concentration (cross-cause prioritization) ap.add_argument( "--cause_concentration_topk", type=int, nargs="*", default=[5, 10, 20, 50], help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)", ) ap.add_argument( "--cause_concentration_write_random_baseline", action="store_true", default=False, help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)", ) args = ap.parse_args() set_deterministic(args.seed) specs = load_models_json(args.models_json) if not specs: raise ValueError("No models provided") export_dir = str(args.export_dir) _ensure_dir(export_dir) cause_names = load_cause_names("labels.csv") # Determine top-K causes from the evaluation split only (model-agnostic), # aligned to time-dependent risk: occurrence within tau_max after context. 
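    # Illustrative example of the rule implemented by pick_focus_causes: with
    # counts_within_tau = [5, 0, 9, 2], death_cause_id = 1 and focus_k = 2, the
    # result is [1, 2, 0] -- death first, then the remaining causes with the
    # highest within-horizon counts (ties broken by smaller cause id).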
first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path) cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] dataset_for_top = HealthDataset( data_prefix=args.data_prefix, covariate_list=cov_list) subset_for_top = build_eval_subset( dataset_for_top, train_ratio=float(first_cfg.get("train_ratio", 0.7)), val_ratio=float(first_cfg.get("val_ratio", 0.15)), seed=int(first_cfg.get("random_seed", 42)), split=args.split, ) loader_top = DataLoader( subset_for_top, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=health_collate_fn, ) tau_max = float(max(args.eval_horizons)) counts, n_total_eval = count_occurs_within_horizon( loader=loader_top, offset_years=args.offset_years, tau_years=tau_max, n_disease=dataset_for_top.n_disease, device=args.device, ) focus_causes = pick_focus_causes( counts_within_tau=counts, n_disease=int(dataset_for_top.n_disease), death_cause_id=int(args.death_cause_id), k=int(args.focus_k), ) top_cause_ids = np.asarray(focus_causes, dtype=int) # Export the chosen focus causes. focus_rows: List[Dict[str, Any]] = [] for r, cid in enumerate(focus_causes, start=1): row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)} if cid in cause_names: row["cause_name"] = cause_names[cid] focus_rows.append(row) focus_fieldnames = ["cause", "rank"] + \ (["cause_name"] if any("cause_name" in r for r in focus_rows) else []) write_simple_csv(os.path.join(export_dir, "focus_causes.csv"), focus_fieldnames, focus_rows) # Metadata for focus causes (within tau_max). top_causes_meta: List[Dict[str, Any]] = [] for cid in focus_causes: n_case = int(counts[int(cid)]) if int( cid) < int(counts.shape[0]) else 0 top_causes_meta.append( { "cause_id": int(cid), "tau_years": float(tau_max), "n_case_within_tau": n_case, "n_control_within_tau": int(n_total_eval - n_case), "n_total_eval": int(n_total_eval), } ) summary_rows: List[Dict[str, Any]] = [] calib_rows: List[Dict[str, Any]] = [] # Experiment exports (accumulated across models) cap_points_rows: List[Dict[str, Any]] = [] cap_curve_rows: List[Dict[str, Any]] = [] conc_rows: List[Dict[str, Any]] = [] conc_base_rows: List[Dict[str, Any]] = [] # Track per-model integrity status for meta JSON. integrity_meta: Dict[str, Any] = {} # Evaluate each model for spec in specs: run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path)) tag = _make_eval_tag(args.split, float(args.offset_years)) # Remember list offsets so we can write per-model slices to the model's run_dir. 
calib_start = len(calib_rows) cfg = load_train_config_for_checkpoint(spec.checkpoint_path) # Identifiers for consistent exports model_id = str(spec.name) model_type = str( cfg.get("model_type", spec.model_type if hasattr(spec, "model_type") else "")) loss_type_id = str( cfg.get("loss_type", spec.loss_type if hasattr(spec, "loss_type") else "")) age_encoder = str(cfg.get("age_encoder", "")) cov_type = "full" if _parse_bool( cfg.get("full_cov", False)) else "partial" cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [ "bmi", "smoking", "alcohol"] dataset = HealthDataset( data_prefix=args.data_prefix, covariate_list=cov_list) subset = build_eval_subset( dataset, train_ratio=float(cfg.get("train_ratio", 0.7)), val_ratio=float(cfg.get("val_ratio", 0.15)), seed=int(cfg.get("random_seed", 42)), split=args.split, ) loader = DataLoader( subset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=health_collate_fn, ) ckpt = torch.load(spec.checkpoint_path, map_location=args.device) backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head( cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path) backbone.load_state_dict(ckpt["model_state_dict"], strict=True) head.load_state_dict(ckpt["head_state_dict"], strict=True) if loss_type == "lognormal_basis_binned_hazard_cif": crit_state = ckpt.get("criterion_state_dict", {}) log_sigma = crit_state.get("log_sigma", None) if isinstance(log_sigma, torch.Tensor): log_sigma_t = log_sigma.to(device=args.device) sigma = torch.exp(log_sigma_t) else: sigma = torch.tensor(float(loss_params.get( "bandwidth_init", 0.7)), device=args.device) bmin = float(loss_params.get("bandwidth_min", 1e-3)) bmax = float(loss_params.get("bandwidth_max", 10.0)) sigma = torch.clamp(sigma, min=bmin, max=bmax) loss_params["sigma"] = sigma ( cif_full, survival, y_cause_within_tau, sex, ) = predict_cifs_for_model( backbone, head, loss_type, bin_edges, loss_params, loader, args.device, args.offset_years, args.eval_horizons, n_disease=int(dataset.n_disease), ) # CIF integrity checks before metrics. integrity_ok, integrity_notes = check_cif_integrity( cif_full, args.eval_horizons, tol=float(args.integrity_tol), name=spec.name, strict=bool(args.integrity_strict), survival=survival, ) integrity_meta[spec.name] = { "integrity_ok": bool(integrity_ok), "integrity_notes": integrity_notes, } # Per-disease metrics for ALL diseases (written into the model's run_dir). model_rows: List[Dict[str, Any]] = [] evaluate_one_model( model_name=spec.name, cif_full=cif_full, y_cause_within_tau=y_cause_within_tau, eval_horizons=args.eval_horizons, out_rows=model_rows, calib_rows=calib_rows, calib_cause_ids=top_cause_ids.tolist(), metric_workers=int(args.metric_workers), progress=str(args.progress), ) # Summary over diseases (mean/median per horizon). 
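        # summarize_over_diseases collapses the per-cause rows above into one row
        # per (metric, horizon): the mean and median of cause_brier / cause_ici /
        # cause_auc across all diseases with a finite value, plus the count of
        # diseases included.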
        # Summary over diseases (mean/median per horizon).
        model_summary_rows = summarize_over_diseases(
            model_rows,
            model_name=spec.name,
            eval_horizons=args.eval_horizons,
        )
        summary_rows.extend(model_summary_rows)

        # ============================================================
        # Experiment: High-Risk Cause Concentration at fixed horizon
        # (cross-cause prioritization accuracy)
        # ============================================================
        topk_causes = [int(x) for x in args.cause_concentration_topk]
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float)
                y_tau_all = np.asarray(y_cause_within_tau[:, :, h_i], dtype=float)
                if sex_mask is not None:
                    p_tau_all = p_tau_all[sex_mask]
                    y_tau_all = y_tau_all[sex_mask]
                for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes):
                    conc_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "horizon": float(tau),
                            "sex": sex_label,
                            **rr,
                        }
                    )
                if bool(args.cause_concentration_write_random_baseline):
                    for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes):
                        conc_base_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "horizon": float(tau),
                                "sex": sex_label,
                                **rr,
                            }
                        )

        # Convenience slices for user-facing experiments (focus causes only).
        cause_cif_focus = cif_full[:, top_cause_ids, :]
        y_within_focus = y_cause_within_tau[:, top_cause_ids, :]

        # ============================================================
        # Experiment 2: High-risk capture points (+ optional curve)
        # ============================================================
        k_pcts = [int(x) for x in args.capture_k_pcts]
        curve_max = int(args.capture_curve_max_pct)
        curve_grid = list(range(1, curve_max + 1)) if curve_max and curve_max > 0 else []
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                for j, cause_id in enumerate(top_cause_ids.tolist()):
                    p = cause_cif_focus[:, j, h_i]
                    y = y_within_focus[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    for r in compute_capture_points(p, y, k_pcts):
                        cap_points_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "horizon": float(tau),
                                "sex": sex_label,
                                **r,
                            }
                        )
                    if curve_grid:
                        for r in compute_capture_points(p, y, curve_grid):
                            cap_curve_rows.append(
                                {
                                    "model_id": model_id,
                                    "model_type": model_type,
                                    "loss_type": loss_type_id,
                                    "age_encoder": age_encoder,
                                    "cov_type": cov_type,
                                    "cause": int(cause_id),
                                    "horizon": float(tau),
                                    "sex": sex_label,
                                    **r,
                                }
                            )
        # Optionally write top-cause counts into the main results CSV as metric rows.
        for tc in top_causes_meta:
            model_rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_case_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_case_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            model_rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_control_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_control_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            model_rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_total_eval",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_total_eval"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

        # Write per-model results into the model's run directory.
        model_calib_rows = calib_rows[calib_start:]
        model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
        model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
        model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")

        write_results_csv(model_out_csv, model_rows)
        write_simple_csv(
            model_summary_csv,
            ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
            model_summary_rows,
        )
        write_calibration_bins_csv(model_calib_csv, model_calib_rows)

        model_meta = {
            "model_name": spec.name,
            "checkpoint_path": spec.checkpoint_path,
            "run_dir": run_dir,
            "split": args.split,
            "offset_years": args.offset_years,
            "eval_horizons": [float(x) for x in args.eval_horizons],
            "n_disease": int(dataset.n_disease),
            "top_cause_ids": top_cause_ids.tolist(),
            "top_causes": top_causes_meta,
            "integrity": {spec.name: integrity_meta.get(spec.name, {})},
            "paths": {
                "results_csv": model_out_csv,
                "summary_csv": model_summary_csv,
                "calibration_bins_csv": model_calib_csv,
            },
        }
        with open(model_meta_json, "w") as f:
            json.dump(model_meta, f, indent=2)
        print(f"Wrote per-model results to {model_out_csv}")

    # Write global summary (across diseases) across all models.
    write_simple_csv(
        args.out_csv,
        ["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
        summary_rows,
    )
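    # The remaining outputs are cross-model artifacts: the shared calibration-bin CSV,
    # the experiment exports under export_dir, a markdown manifest, and the top-level
    # meta JSON written to args.out_meta_json.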
    # Write calibration curve points to a separate CSV.
    out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
    calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
    write_calibration_bins_csv(calib_csv_path, calib_rows)

    # Write experiment exports
    write_simple_csv(
        os.path.join(export_dir, "lift_capture_points.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "k_pct",
            "n_targeted",
            "events_targeted",
            "events_total",
            "event_capture_rate",
            "precision_in_targeted",
        ],
        cap_points_rows,
    )
    if cap_curve_rows:
        write_simple_csv(
            os.path.join(export_dir, "lift_capture_curve.csv"),
            [
                "model_id",
                "model_type",
                "loss_type",
                "age_encoder",
                "cov_type",
                "cause",
                "horizon",
                "sex",
                "k_pct",
                "n_targeted",
                "events_targeted",
                "events_total",
                "event_capture_rate",
                "precision_in_targeted",
            ],
            cap_curve_rows,
        )
    write_simple_csv(
        os.path.join(export_dir, "high_risk_cause_concentration.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "horizon",
            "sex",
            "topk",
            "event_rate_mean",
            "event_rate_median",
            "recall_mean",
            "recall_median",
            "n_total",
            "n_valid_recall",
        ],
        conc_rows,
    )
    if conc_base_rows:
        write_simple_csv(
            os.path.join(export_dir, "high_risk_cause_concentration_random_baseline.csv"),
            [
                "model_id",
                "model_type",
                "loss_type",
                "age_encoder",
                "cov_type",
                "horizon",
                "sex",
                "topk",
                "baseline_event_rate_mean",
                "baseline_event_rate_p05",
                "baseline_event_rate_p95",
                "baseline_recall_mean",
                "baseline_recall_p05",
                "baseline_recall_p95",
                "n_total",
                "n_valid_recall",
                "k_total",
                "baseline_method",
            ],
            conc_base_rows,
        )

    # Manifest markdown (stable, user-facing)
    manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
    with open(manifest_path, "w", encoding="utf-8") as f:
        f.write(
            "# Evaluation Exports Manifest\n\n"
            "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
            "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
            "## Files\n\n"
            "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
            "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
            "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
            "- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart vs K.\n"
            "- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n"
        )

    meta = {
        "split": args.split,
        "offset_years": args.offset_years,
        "eval_horizons": [float(x) for x in args.eval_horizons],
        "tau_max": float(tau_max),
        "n_disease": int(dataset_for_top.n_disease),
        "top_cause_ids": top_cause_ids.tolist(),
        "top_causes": top_causes_meta,
        "integrity": integrity_meta,
        "notes": {
            "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
            "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
            "secondary_metrics": "cause_auc (discrimination)",
            "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
            "exports_dir": export_dir,
            "focus_causes": focus_causes,
        },
    }
    with open(args.out_meta_json, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
    print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
    print(f"Wrote {args.out_meta_json}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
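
# ------------------------------------------------------------------------------
# Illustrative usage (sketch, not authoritative). The CLI flags below are assumed
# to mirror the args.* attributes referenced in main(); the actual argparse
# definitions live elsewhere in this file and may use different names or defaults.
# All paths and values are placeholders.
#
#   python <this_script>.py \
#       --data_prefix /path/to/data_prefix \
#       --split test \
#       --offset_years 1.0 \
#       --eval_horizons 0.25 0.5 1.0 2.0 5.0 10.0 \
#       --batch_size 256 \
#       --device cuda \
#       --out_csv results/eval_summary.csv \
#       --out_meta_json results/eval_meta.json
#
# Downstream sketch for one export (assumes pandas is available in the analysis
# environment and uses a placeholder export directory):
#
#   import pandas as pd
#   lift = pd.read_csv("exports/lift_capture_points.csv")
#   at_5y = lift[lift["horizon"] == 5.0]
#   print(at_5y[["model_id", "cause", "k_pct", "event_capture_rate", "precision_in_targeted"]])
# ------------------------------------------------------------------------------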