import argparse
import csv
import json
import math
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi, SimpleHead
# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
DAYS_PER_YEAR = 365.25
DEFAULT_DEATH_CAUSE_ID = 1256
# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
name: str
model_type: str # delphi_fork | sap_delphi
loss_type: str # exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif
full_cov: bool
checkpoint_path: str
# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
if isinstance(x, bool):
return x
s = str(x).strip().lower()
if s in {"true", "1", "yes", "y"}:
return True
if s in {"false", "0", "no", "n"}:
return False
raise ValueError(f"Cannot parse boolean: {x!r}")
def load_models_json(path: str) -> List[ModelSpec]:
with open(path, "r") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("models_json must be a list of model entries")
specs: List[ModelSpec] = []
for row in data:
specs.append(
ModelSpec(
name=str(row["name"]),
model_type=str(row["model_type"]),
loss_type=str(row["loss_type"]),
full_cov=_parse_bool(row["full_cov"]),
checkpoint_path=str(row["checkpoint_path"]),
)
)
return specs
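# Illustrative sketch of the expected models_json structure; the name and
# checkpoint path below are hypothetical and shown only to document the fields
# parsed above (model_type: delphi_fork | sap_delphi; loss_type: exponential |
# discrete_time_cif | lognormal_basis_binned_hazard_cif).
# [
#   {
#     "name": "delphi_fork_exp",
#     "model_type": "delphi_fork",
#     "loss_type": "exponential",
#     "full_cov": false,
#     "checkpoint_path": "runs/delphi_fork_exp/best.pt"
#   }
# ]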
def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
cfg_path = os.path.join(run_dir, "train_config.json")
with open(cfg_path, "r") as f:
cfg = json.load(f)
return cfg
def build_eval_subset(
dataset: HealthDataset,
train_ratio: float,
val_ratio: float,
seed: int,
split: str,
):
n_total = len(dataset)
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
train_ds, val_ds, test_ds = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
if split == "train":
return train_ds
if split == "val":
return val_ds
if split == "test":
return test_ds
if split == "all":
return dataset
raise ValueError("split must be one of: train, val, test, all")
# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Select per-sample prediction context index.
IMPORTANT SEMANTICS:
- The last observed token time is treated as the FOLLOW-UP END time.
- We pick the last valid token with time <= (followup_end_time - offset).
- We do NOT interpret followup_end_time as an event time.
Returns:
keep_mask: (B,) bool, which samples have a valid context
t_ctx: (B,) long, index into sequence
t_ctx_time: (B,) float, time (days) at context
"""
# valid tokens are event != 0 (padding is 0)
valid = event_seq != 0
lengths = valid.sum(dim=1)
last_idx = torch.clamp(lengths - 1, min=0)
b = torch.arange(event_seq.size(0), device=event_seq.device)
followup_end_time = time_seq[b, last_idx]
t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
eligible = valid & (time_seq <= t_cut.unsqueeze(1))
eligible_counts = eligible.sum(dim=1)
keep = eligible_counts > 0
t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
t_ctx_time = time_seq[b, t_ctx]
return keep, t_ctx, t_ctx_time
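# Illustrative sketch (toy tensors, not called by the pipeline): one padded
# sequence whose follow-up ends at day 4000. With a 0.5-year offset the cut
# time is ~day 3817, so the selected context is the token at day 3700 (index 2).
def _example_select_context_indices() -> None:
    event_seq = torch.tensor([[1, 5, 7, 9, 0, 0]])  # 0 = padding
    time_seq = torch.tensor([[0.0, 1000.0, 3700.0, 4000.0, 0.0, 0.0]])
    keep, t_ctx, t_ctx_time = select_context_indices(
        event_seq, time_seq, offset_years=0.5)
    assert bool(keep[0]) and int(t_ctx[0]) == 2 and float(t_ctx_time[0]) == 3700.0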
def next_event_after_context(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return next disease event after context.
Returns:
dt_years: (B,) float, time to next disease in years; +inf if none
cause: (B,) long, disease id in [0,K) for next event; -1 if none
"""
B, L = event_seq.shape
b = torch.arange(B, device=event_seq.device)
t0 = time_seq[b, t_ctx]
# Allow same-day events while excluding the context token itself.
# We rely on time-sorted sequences and select the FIRST valid future event by index.
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
idx_min = torch.where(
future, idxs, torch.full_like(idxs, L)).min(dim=1).values
has = idx_min < L
t_next = torch.where(has, idx_min, torch.zeros_like(idx_min))
t_next_time = time_seq[b, t_next]
dt_days = t_next_time - t0
dt_years = dt_days / DAYS_PER_YEAR
dt_years = torch.where(
has, dt_years, torch.full_like(dt_years, float("inf")))
cause_token = event_seq[b, t_next]
cause = (cause_token - 2).to(torch.long)
cause = torch.where(has, cause, torch.full_like(cause, -1))
return dt_years, cause
def multi_hot_ever_within_horizon(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
tau_years: float,
n_disease: int,
) -> torch.Tensor:
"""Binary labels: disease k occurs within tau after context (any occurrence)."""
B, L = event_seq.shape
b = torch.arange(B, device=event_seq.device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (tau_years * DAYS_PER_YEAR)
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
# Include same-day events after context, exclude any token at/before context index.
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
if not in_window.any():
return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
y[b_idx, disease_ids] = True
return y
def multi_hot_ever_after_context_anytime(
event_seq: torch.Tensor,
t_ctx: torch.Tensor,
n_disease: int,
) -> torch.Tensor:
"""Binary labels: disease k occurs ANYTIME after the prediction context.
This is Delphi2M-compatible for Task A case/control definition.
Same-day events are included as long as they occur after the context token index.
"""
B, L = event_seq.shape
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
if not future.any():
return y
b_idx, t_idx = future.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
y[b_idx, disease_ids] = True
return y
def multi_hot_selected_causes_within_horizon(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
tau_years: float,
cause_ids: torch.Tensor,
n_disease: int,
) -> torch.Tensor:
"""Labels for selected causes only: does cause k occur within tau after context?"""
B, L = event_seq.shape
device = event_seq.device
b = torch.arange(B, device=device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (tau_years * DAYS_PER_YEAR)
idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
if not in_window.any():
return out
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
# Filter to selected causes via a boolean membership mask over the global disease space.
selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
selected[cause_ids] = True
keep = selected[disease_ids]
if not keep.any():
return out
b_idx = b_idx[keep]
disease_ids = disease_ids[keep]
# Map disease_id -> local index in cause_ids
# Build a lookup table (global disease space) where lookup[disease_id] = local_index
lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
local = lookup[disease_ids]
out[b_idx, local] = True
return out
# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
logits: torch.Tensor,
taus: Sequence[float],
eps: float = 1e-6,
return_survival: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Convert exponential cause-specific logits -> CIFs at taus.
logits: (B, K)
returns: (B, K, H) or (cif, survival) if return_survival
"""
hazards = F.softplus(logits) + eps
total = hazards.sum(dim=1, keepdim=True) # (B,1)
taus_t = torch.tensor(list(taus), device=logits.device,
dtype=hazards.dtype).view(1, 1, -1)
total_h = total.unsqueeze(-1) # (B,1,1)
# (1 - exp(-Lambda * tau))
one_minus_surv = 1.0 - torch.exp(-total_h * taus_t)
frac = hazards / torch.clamp(total, min=eps)
cif = frac.unsqueeze(-1) * one_minus_surv # (B,K,H)
# If total==0, set to 0
cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))
if not return_survival:
return cif
survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H)
# Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
nonzero = (total.squeeze(1) > 0).unsqueeze(1)
survival = torch.where(nonzero, survival, torch.ones_like(survival))
return cif, survival
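# Minimal sketch with hypothetical logits: for two competing causes,
# CIF_k(tau) = (h_k / H) * (1 - exp(-H * tau)) with h_k = softplus(logit_k)
# and H = sum_k h_k, so the cause CIFs and the survival exp(-H * tau) sum to 1.
def _example_exponential_cif() -> None:
    logits = torch.tensor([[0.0, 1.0]])  # (B=1, K=2)
    cif, surv = cifs_from_exponential_logits(
        logits, taus=[1.0], return_survival=True)
    total = float(cif.sum(dim=1)[0, 0] + surv[0, 0])
    assert abs(total - 1.0) < 1e-5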
def cifs_from_discrete_time_logits(
logits: torch.Tensor,
bin_edges: Sequence[float],
taus: Sequence[float],
return_survival: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Convert discrete-time CIF logits -> CIFs at taus.
logits: (B, K+1, n_bins+1)
bin_edges: len=n_bins+1 (including 0 and inf)
taus: subset of finite bin edges (recommended)
returns: (B, K, H) or (cif, survival) if return_survival
"""
if logits.ndim != 3:
raise ValueError("Expected logits shape (B, K+1, n_bins+1)")
B, K_plus_1, n_bins_plus_1 = logits.shape
K = K_plus_1 - 1
edges = [float(x) for x in bin_edges]
# drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf
finite_edges = [e for e in edges[1:] if math.isfinite(e)]
n_bins = len(finite_edges)
if n_bins_plus_1 != len(edges):
raise ValueError("logits last dim must match len(bin_edges)")
probs = torch.softmax(logits, dim=1) # (B, K+1, n_bins+1)
# use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot)
hazards = probs[:, :K, 1: 1 + n_bins] # (B,K,n_bins)
p_comp = probs[:, K, 1: 1 + n_bins] # (B,n_bins)
# survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v< u} p_comp[v]
ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
cum = torch.cumprod(p_comp, dim=1)
s_prev = torch.cat([ones, cum[:, :-1]], dim=1) # (B, n_bins)
cif_bins = torch.cumsum(s_prev.unsqueeze(
1) * hazards, dim=2) # (B,K,n_bins)
# Robust mapping from tau -> edge index (floating-point safe).
# taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
finite_edges_arr = np.asarray(finite_edges, dtype=float)
tau_to_idx: List[int] = []
for tau in taus:
tau_f = float(tau)
if not math.isfinite(tau_f):
raise ValueError("taus must be finite for discrete-time CIF")
diffs = np.abs(finite_edges_arr - tau_f)
j = int(np.argmin(diffs))
if diffs[j] > 1e-6:
raise ValueError(
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
)
tau_to_idx.append(j)
idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H)
if not return_survival:
return cif
# Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
survival = survival_bins.index_select(dim=1, index=idx) # (B,H)
return cif, survival
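# Shape-only sketch using the default bin edges (8 edges, 6 finite interior
# edges usable as horizons): with K=3 diseases the head emits logits of shape
# (B, K+1, len(bin_edges)), and the conversion returns one CIF column per
# requested horizon plus the matching survival.
def _example_discrete_time_cif() -> None:
    logits = torch.zeros((2, 4, len(DEFAULT_BIN_EDGES)))  # (B=2, K+1=4, 8)
    cif, surv = cifs_from_discrete_time_logits(
        logits, DEFAULT_BIN_EDGES, taus=[0.72, 10.0], return_survival=True)
    assert cif.shape == (2, 3, 2) and surv.shape == (2, 2)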
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
z = torch.clamp(z, -12.0, 12.0)
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
def cifs_from_lognormal_basis_binned_hazard_logits(
logits: torch.Tensor,
*,
centers: Sequence[float],
sigma: torch.Tensor,
bin_edges: Sequence[float],
taus: Sequence[float],
eps: float = 1e-8,
alpha_floor: float = 0.0,
return_survival: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Convert Route-3 binned hazard logits -> CIFs at taus.
logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
taus are expected to align with finite bin edges.
"""
if logits.ndim not in {2, 3}:
raise ValueError("logits must be 2D or 3D")
if sigma.ndim != 0:
raise ValueError("sigma must be a scalar tensor")
device = logits.device
dtype = logits.dtype
centers_t = torch.tensor([float(x)
for x in centers], device=device, dtype=dtype)
r = int(centers_t.numel())
if r <= 0:
raise ValueError("centers must be non-empty")
offset = 0
if logits.ndim == 3:
j = int(logits.shape[1])
if int(logits.shape[2]) != r:
raise ValueError(
f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
)
else:
d = int(logits.shape[1])
if d % r == 0:
jr = d
elif (d - 1) % r == 0:
offset = 1
jr = d - 1
else:
raise ValueError(
f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}")
j = jr // r
if j <= 0:
raise ValueError("Inferred J must be >= 1")
edges = [float(x) for x in bin_edges]
finite_edges = [e for e in edges[1:] if math.isfinite(e)]
n_bins = len(finite_edges)
if n_bins <= 0:
raise ValueError("bin_edges must contain at least one finite edge")
# Build finite bins [edges[k-1], edges[k]) for k=1..n_bins
left = torch.tensor(edges[:n_bins], device=device, dtype=dtype)
right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype)
# Stable t_min clamp (aligns with training loss rule).
t_min = 1e-12
if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
t_min = edges[1] * 1e-6
t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
left_is_zero = left <= 0
left_clamped = torch.clamp(left, min=t_min_t)
log_left = torch.log(left_clamped)
right_clamped = torch.clamp(right, min=t_min_t)
log_right = torch.log(right_clamped)
sigma_c = sigma.to(device=device, dtype=dtype)
z_left = (log_left.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
z_right = (log_right.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
cdf_left = _normal_cdf_stable(z_left)
if left_is_zero.any():
cdf_left = torch.where(left_is_zero.unsqueeze(-1),
torch.zeros_like(cdf_left), cdf_left)
cdf_right = _normal_cdf_stable(z_right)
delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0) # (n_bins, R)
if logits.ndim == 3:
alpha = F.softplus(logits) + float(alpha_floor) # (B,J,R)
else:
logits_used = logits[:, offset:]
alpha = (F.softplus(logits_used) + float(alpha_floor)
).view(logits.size(0), j, r) # (B,J,R)
h_jk = torch.einsum("bjr,kr->bjk", alpha, delta_basis) # (B,J,n_bins)
h_k = h_jk.sum(dim=1) # (B,n_bins)
h_k = torch.clamp(h_k, min=eps)
h_jk = torch.clamp(h_jk, min=eps)
p_comp = torch.exp(-h_k) # (B,n_bins)
one_minus = -torch.expm1(-h_k) # (B,n_bins) = 1-exp(-H)
ratio = h_jk / torch.clamp(h_k.unsqueeze(1), min=eps)
p_event = one_minus.unsqueeze(1) * ratio # (B,J,n_bins)
ones = torch.ones((alpha.size(0), 1), device=device, dtype=dtype)
cum = torch.cumprod(p_comp, dim=1) # survival after each bin
s_prev = torch.cat([ones, cum[:, :-1]], dim=1) # survival before each bin
cif_bins = torch.cumsum(s_prev.unsqueeze(
1) * p_event, dim=2) # (B,J,n_bins)
finite_edges_arr = np.asarray(finite_edges, dtype=float)
tau_to_idx: List[int] = []
for tau in taus:
tau_f = float(tau)
if not math.isfinite(tau_f):
raise ValueError("taus must be finite for discrete-time CIF")
diffs = np.abs(finite_edges_arr - tau_f)
idx0 = int(np.argmin(diffs))
if diffs[idx0] > 1e-6:
raise ValueError(
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
)
tau_to_idx.append(idx0)
idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
cif = cif_bins.index_select(dim=2, index=idx) # (B,J,H)
if not return_survival:
return cif
survival = cum.index_select(dim=1, index=idx) # (B,H)
return cif, survival
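# Shape-only sketch with hypothetical basis centers and a fixed bandwidth:
# J=5 causes, R=3 log-normal basis functions, logits of shape (B, J, R).
def _example_lognormal_binned_hazard_cif() -> None:
    logits = torch.zeros((2, 5, 3))  # (B=2, J=5, R=3)
    cif, surv = cifs_from_lognormal_basis_binned_hazard_logits(
        logits,
        centers=[-1.0, 0.0, 1.0],  # hypothetical mu_r in log-time
        sigma=torch.tensor(0.7),
        bin_edges=DEFAULT_BIN_EDGES,
        taus=[1.61],
        return_survival=True,
    )
    assert cif.shape == (2, 5, 1) and surv.shape == (2, 1)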
# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
cause_cif: np.ndarray,
horizons: Sequence[float],
*,
tol: float = 1e-6,
name: str = "",
strict: bool = False,
survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
"""Run basic sanity checks on CIF arrays.
Args:
cause_cif: (N, K, H)
horizons: length H
tol: tolerance for inequalities
name: model name for messages
strict: if True, raise ValueError on first failure
survival: optional (N, H) survival values at the same horizons
Returns:
(integrity_ok, notes)
"""
notes: List[str] = []
model_tag = f"[{name}] " if name else ""
def _fail(msg: str) -> None:
full = model_tag + msg
if strict:
raise ValueError(full)
print("WARNING:", full)
notes.append(msg)
cif = np.asarray(cause_cif)
if cif.ndim != 3:
_fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
return False, notes
N, K, H = cif.shape
if H != len(horizons):
_fail(
f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")
# (5) Finite
if not np.isfinite(cif).all():
_fail("integrity: non-finite values (NaN/Inf) in cause_cif")
# (1) Range
cmin = float(np.nanmin(cif))
cmax = float(np.nanmax(cif))
if cmin < -tol:
_fail(f"integrity: range min={cmin} < -tol={-tol}")
if cmax > 1.0 + tol:
_fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")
# (2) Monotonicity in horizons (per n,k)
diffs = np.diff(cif, axis=2)
if diffs.size > 0:
if np.nanmin(diffs) < -tol:
_fail("integrity: monotonicity violated (found negative diff along horizons)")
# (3) Probability mass: sum_k CIF <= 1
mass = np.sum(cif, axis=1) # (N,H)
mass_max = float(np.nanmax(mass))
if mass_max > 1.0 + tol:
_fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")
# (4) Conservation with survival, if provided
if survival is None:
warn = "integrity: survival not provided; skipping conservation check"
if strict:
# still skip (requested behavior), but keep message for context
notes.append(warn)
else:
print("WARNING:", model_tag + warn)
notes.append(warn)
else:
s = np.asarray(survival, dtype=float)
if s.shape != (N, H):
_fail(
f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
else:
recon = 1.0 - s
err = np.abs(recon - mass)
# Discrete-time should be very tight; exponential may accumulate slightly more numerical error.
tol_cons = max(float(tol), 1e-4)
if float(np.nanmax(err)) > tol_cons:
_fail(
f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})")
ok = len([n for n in notes if not n.endswith(
"skipping conservation check")]) == 0
return ok, notes
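# Minimal sketch (synthetic arrays): an all-zero CIF with survival fixed at 1
# satisfies every check (range, monotonicity, mass, conservation), so the
# function returns ok=True with no notes.
def _example_check_cif_integrity() -> None:
    cif = np.zeros((4, 3, 2))   # (N, K, H)
    surv = np.ones((4, 2))      # (N, H), equals 1 - sum_k CIF here
    ok, notes = check_cif_integrity(cif, horizons=[1.0, 5.0], survival=surv)
    assert ok and not notes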
# ============================================================
# Metrics
# ============================================================
# --- Rank-based ROC AUC (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
"""Vectorized midrank computation (ties -> average ranks)."""
x = np.asarray(x, dtype=float)
n = int(x.shape[0])
if n == 0:
return np.asarray([], dtype=float)
order = np.argsort(x, kind="mergesort")
z = x[order]
# Find tie groups in sorted order.
diff = np.diff(z)
# boundaries includes 0 and n
boundaries = np.concatenate(
    [np.array([0], dtype=int),
     np.nonzero(diff != 0)[0] + 1,
     np.array([n], dtype=int)]
)
starts = boundaries[:-1]
ends = boundaries[1:]
lens = ends - starts
# Midrank for each group in 1-based rank space.
mids = 0.5 * (starts + ends - 1) + 1.0
t_sorted = np.repeat(mids, lens).astype(float, copy=False)
out = np.empty(n, dtype=float)
out[order] = t_sorted
return out
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
"""Rank-based ROC AUC via MannWhitney U statistic (ties handled by midranks).
Returns NaN for degenerate labels.
"""
y = np.asarray(y_true, dtype=int)
s = np.asarray(y_score, dtype=float)
if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]:
raise ValueError("roc_auc_rank expects 1D arrays of equal length")
m = int(np.sum(y == 1))
n = int(np.sum(y == 0))
if m == 0 or n == 0:
return float("nan")
ranks = compute_midrank(s)
sum_pos = float(np.sum(ranks[y == 1]))
auc = (sum_pos - m * (m + 1) / 2.0) / (m * n)
return float(auc)
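# Worked example (hand-checked): scores separate the classes except for one
# tied positive/negative pair, so U = 7.5 over m*n = 8 pairs and AUC = 0.9375.
def _example_roc_auc_rank() -> None:
    y = np.array([0, 0, 0, 0, 1, 1])
    s = np.array([0.1, 0.2, 0.3, 0.5, 0.5, 0.9])
    assert abs(roc_auc_rank(y, s) - 0.9375) < 1e-12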
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
return float(np.mean((p - y) ** 2))
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
bins, ici = _calibration_bins_and_ici(
p, y, n_bins=int(n_bins), return_bins=True)
return {"bins": bins, "ici": float(ici)}
def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
"""Fast ICI only (no per-bin point export)."""
_, ici = _calibration_bins_and_ici(
p, y, n_bins=int(n_bins), return_bins=False)
return float(ici)
def _calibration_bins_and_ici(
p: np.ndarray,
y: np.ndarray,
*,
n_bins: int,
return_bins: bool,
) -> Tuple[List[Dict[str, Any]], float]:
"""Vectorized quantile binning for calibration + ICI."""
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
if p.size == 0:
return ([], float("nan")) if return_bins else ([], float("nan"))
q = np.linspace(0.0, 1.0, int(n_bins) + 1)
edges = np.quantile(p, q)
edges = np.asarray(edges, dtype=float)
edges[0] = -np.inf
edges[-1] = np.inf
# Bin assignment: i if edges[i] < p <= edges[i+1]
bin_idx = np.searchsorted(edges, p, side="right") - 1
bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1)
counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float)
sum_p = np.bincount(bin_idx, weights=p,
minlength=int(n_bins)).astype(float)
sum_y = np.bincount(bin_idx, weights=y,
minlength=int(n_bins)).astype(float)
nonempty = counts > 0
if not np.any(nonempty):
return ([], float("nan")) if return_bins else ([], float("nan"))
p_mean = np.zeros(int(n_bins), dtype=float)
y_mean = np.zeros(int(n_bins), dtype=float)
p_mean[nonempty] = sum_p[nonempty] / counts[nonempty]
y_mean[nonempty] = sum_y[nonempty] / counts[nonempty]
diffs = np.abs(p_mean[nonempty] - y_mean[nonempty])
ici = float(np.mean(diffs)) if diffs.size else float("nan")
if not return_bins:
return [], ici
bins: List[Dict[str, Any]] = []
idxs = np.nonzero(nonempty)[0]
for i in idxs.tolist():
bins.append(
{
"bin": int(i),
"p_mean": float(p_mean[i]),
"y_mean": float(y_mean[i]),
"n": int(counts[i]),
}
)
return bins, ici
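# Sanity sketch (synthetic data, seeded): when outcomes are drawn from the
# predicted probabilities themselves, the quantile bins are well calibrated
# and the ICI should be close to zero.
def _example_calibration_ici() -> None:
    rng = np.random.default_rng(0)
    p = rng.uniform(0.0, 1.0, size=2000)
    y = (rng.uniform(size=2000) < p).astype(float)
    assert calibration_ici_only(p, y, n_bins=10) < 0.1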
def _progress_line(done: int, total: int, prefix: str = "") -> str:
total_i = max(1, int(total))
done_i = max(0, min(int(done), total_i))
width = 28
frac = done_i / total_i
filled = int(round(width * frac))
bar = "#" * filled + "-" * (width - filled)
pct = 100.0 * frac
return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)"
def _should_show_progress(mode: str) -> bool:
m = str(mode).strip().lower()
if m in {"0", "false", "no", "none", "off"}:
return False
# Default: show if interactive.
if m in {"auto", "1", "true", "yes", "on", "bar"}:
try:
return bool(sys.stdout.isatty())
except Exception:
return True
return True
def _safe_float(x: Any, default: float = float("nan")) -> float:
try:
return float(x)
except Exception:
return float(default)
def _ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True)
def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
"""Load 0-based cause_id -> name mapping.
labels.csv is assumed to be one label per line, in disease-id order.
"""
if not os.path.exists(path):
return {}
mapping: Dict[int, str] = {}
with open(path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
name = line.strip()
if name:
mapping[int(i)] = name
return mapping
def pick_focus_causes(
*,
counts_within_tau: Optional[np.ndarray],
n_disease: int,
death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
k: int = 5,
) -> List[int]:
"""Pick focus causes for user-facing evaluation.
Rule:
1) Always include death_cause_id first.
2) Then add K additional causes by descending event count if available.
If counts_within_tau is None, fall back to descending cause_id coverage proxy.
Notes:
- counts_within_tau is expected to be shape (n_disease,).
- Deterministic: ties broken by smaller cause id.
"""
n_disease_i = int(n_disease)
if death_cause_id < 0 or death_cause_id >= n_disease_i:
print(
f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); "
"it will be omitted from focus causes."
)
focus: List[int] = []
else:
focus = [int(death_cause_id)]
candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)]
if counts_within_tau is not None:
c = np.asarray(counts_within_tau).astype(float)
if c.shape[0] != n_disease_i:
print(
"WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering."
)
counts_within_tau = None
else:
# Sort by (-count, cause_id)
order = sorted(candidates, key=lambda i: (-float(c[i]), int(i)))
order = [i for i in order if float(c[i]) > 0]
focus.extend([int(i) for i in order[: int(k)]])
if counts_within_tau is None:
# Fallback: deterministic coverage proxy (descending id, excluding death), then take K.
# (Real coverage requires data; this path is mostly for robustness.)
order = sorted(candidates, key=lambda i: (-int(i)))
focus.extend([int(i) for i in order[: int(k)]])
# De-dup while preserving order
seen = set()
out: List[int] = []
for cid in focus:
if cid not in seen:
out.append(cid)
seen.add(cid)
return out
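# Illustrative sketch (hypothetical counts, 5 diseases): death_cause_id is
# always listed first, then the remaining causes by descending count with
# ties broken by the smaller cause id.
def _example_pick_focus_causes() -> None:
    counts = np.array([10, 0, 50, 50, 3])
    focus = pick_focus_causes(counts_within_tau=counts,
                              n_disease=5, death_cause_id=4, k=2)
    assert focus == [4, 2, 3]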
def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
_ensure_dir(os.path.dirname(os.path.abspath(path)) or ".")
with open(path, "w", newline="", encoding="utf-8") as f:
w = csv.DictWriter(f, fieldnames=fieldnames)
w.writeheader()
for r in rows:
w.writerow(r)
def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
"""Return list of (sex_label, mask) slices including an 'all' slice.
If sex is missing, returns only ('all', None).
"""
out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
if sex is None:
return out
s = np.asarray(sex)
if s.ndim != 1:
return out
for val in [0, 1]:
m = (s == val)
if int(np.sum(m)) > 0:
out.append((str(val), m))
return out
def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1))
edges = np.asarray(edges, dtype=float)
edges[0] = -np.inf
edges[-1] = np.inf
return edges
def compute_risk_stratification_bins(
p: np.ndarray,
y: np.ndarray,
*,
q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
"""Compute quantile-based risk strata and a compact summary."""
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
n = int(p.shape[0])
if n == 0:
return 0, [], {
"y_overall": float("nan"),
"top_decile_y_rate": float("nan"),
"bottom_half_y_rate": float("nan"),
"lift_top10_vs_bottom50": float("nan"),
"slope_pred_vs_obs": float("nan"),
}
# Choose quantiles robustly.
q = int(q_default)
if n < 200:
q = 5
edges = _quantile_edges(p, q)
y_overall = float(np.mean(y))
bin_rows: List[Dict[str, Any]] = []
p_means: List[float] = []
y_rates: List[float] = []
n_bins: List[int] = []
for i in range(q):
mask = (p > edges[i]) & (p <= edges[i + 1])
nb = int(np.sum(mask))
if nb == 0:
# Keep the row for consistent plotting; set NaNs.
bin_rows.append(
{
"q": int(i + 1),
"n_bin": 0,
"p_mean": float("nan"),
"y_rate": float("nan"),
"y_overall": y_overall,
"lift_vs_overall": float("nan"),
}
)
continue
p_mean = float(np.mean(p[mask]))
y_rate = float(np.mean(y[mask]))
lift = float(y_rate / y_overall) if y_overall > 0 else float("nan")
bin_rows.append(
{
"q": int(i + 1),
"n_bin": nb,
"p_mean": p_mean,
"y_rate": y_rate,
"y_overall": y_overall,
"lift_vs_overall": lift,
}
)
p_means.append(p_mean)
y_rates.append(y_rate)
n_bins.append(nb)
# Summary
top_mask = (p > edges[q - 1]) & (p <= edges[q])
bot_half_mask = (p > edges[0]) & (p <= edges[q // 2])
top_y = float(np.mean(y[top_mask])) if int(
np.sum(top_mask)) > 0 else float("nan")
bot_y = float(np.mean(y[bot_half_mask])) if int(
np.sum(bot_half_mask)) > 0 else float("nan")
lift_top_vs_bottom = float(top_y / bot_y) if (np.isfinite(top_y)
and np.isfinite(bot_y) and bot_y > 0) else float("nan")
slope = float("nan")
if len(p_means) >= 2:
# Weighted least squares slope of y_rate ~ p_mean.
x = np.asarray(p_means, dtype=float)
yy = np.asarray(y_rates, dtype=float)
w = np.asarray(n_bins, dtype=float)
xm = float(np.average(x, weights=w))
ym = float(np.average(yy, weights=w))
denom = float(np.sum(w * (x - xm) ** 2))
if denom > 0:
slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom)
summary = {
"y_overall": y_overall,
"top_decile_y_rate": top_y,
"bottom_half_y_rate": bot_y,
"lift_top10_vs_bottom50": lift_top_vs_bottom,
"slope_pred_vs_obs": slope,
}
return q, bin_rows, summary
def compute_capture_points(
p: np.ndarray,
y: np.ndarray,
k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
n = int(p.shape[0])
if n == 0:
return []
order = np.argsort(-p)
y_sorted = y[order]
events_total = float(np.sum(y_sorted))
rows: List[Dict[str, Any]] = []
for k in k_pcts:
kf = float(k)
n_targeted = int(math.ceil(n * kf / 100.0))
n_targeted = max(1, min(n_targeted, n))
events_targeted = float(np.sum(y_sorted[:n_targeted]))
capture = float(events_targeted /
events_total) if events_total > 0 else float("nan")
precision = float(events_targeted / float(n_targeted))
rows.append(
{
"k_pct": int(k),
"n_targeted": int(n_targeted),
"events_targeted": float(events_targeted),
"events_total": float(events_total),
"event_capture_rate": capture,
"precision_in_targeted": precision,
}
)
return rows
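# Toy sketch: if the 10% of people with the highest predicted risk are exactly
# the people who have events, targeting the top 10% captures all events.
def _example_compute_capture_points() -> None:
    p = np.linspace(0.0, 1.0, 100)
    y = (p > 0.9).astype(float)  # the 10 highest-risk people are the cases
    rows = compute_capture_points(p, y, k_pcts=[10])
    assert rows[0]["event_capture_rate"] == 1.0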
def compute_event_rate_at_topk_causes(
p_tau: np.ndarray,
y_tau: np.ndarray,
topk_list: Sequence[int],
) -> List[Dict[str, Any]]:
"""Compute Event Rate@K for cross-cause prioritization.
For each individual, rank causes by predicted risk p_tau at a fixed horizon.
For each K, select top-K causes and compute the fraction that occur within the horizon.
Args:
p_tau: (N, K) predicted CIFs at a fixed horizon
y_tau: (N, K) binary labels (0/1) whether cause occurs within the horizon
topk_list: list of K values to evaluate
Returns:
List of rows with:
- topk
- event_rate_mean / event_rate_median
- recall_mean / recall_median (averaged over individuals with >=1 true cause)
- n_total / n_valid_recall
"""
p = np.asarray(p_tau, dtype=float)
y = np.asarray(y_tau, dtype=float)
if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape:
raise ValueError(
"compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape")
n, k_total = p.shape
if n == 0 or k_total == 0:
out: List[Dict[str, Any]] = []
for kk in topk_list:
out.append(
{
"topk": int(max(1, int(kk))),
"event_rate_mean": float("nan"),
"event_rate_median": float("nan"),
"recall_mean": float("nan"),
"recall_median": float("nan"),
"n_total": int(n),
"n_valid_recall": 0,
}
)
return out
# Sanitize K list.
topks = sorted({int(x) for x in topk_list if int(x) > 0})
if not topks:
return []
max_k = min(int(max(topks)), int(k_total))
if max_k <= 0:
return []
# Efficient: get top max_k causes per individual, then sort within those.
part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k] # (N, max_k)
p_part = np.take_along_axis(p, part, axis=1)
order = np.argsort(-p_part, axis=1)
top_sorted = np.take_along_axis(part, order, axis=1) # (N, max_k)
out_rows: List[Dict[str, Any]] = []
for kk in topks:
kk_eff = min(int(kk), int(k_total))
idx = top_sorted[:, :kk_eff]
y_sel = np.take_along_axis(y, idx, axis=1)
# Selected true causes per person
hit = np.sum(y_sel, axis=1)
# Precision-like: fraction of selected causes that occur
per_person = (hit / float(kk_eff)) if kk_eff > 0 else np.full((n,), np.nan)
# Recall@K: fraction of true causes covered by top-K (undefined when no true cause)
g = np.sum(y, axis=1)
valid = g > 0
recall = np.full((n,), np.nan, dtype=float)
recall[valid] = hit[valid] / g[valid]
out_rows.append(
{
"topk": int(kk_eff),
"event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"),
"event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"),
"recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"),
"recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"),
"n_total": int(n),
"n_valid_recall": int(np.sum(valid)),
}
)
return out_rows
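# Toy sketch (N=2 people, K=3 causes): each person's top-2 predicted causes
# contain exactly one true cause, so EventRate@2 = 1/2 for both individuals.
def _example_event_rate_at_topk() -> None:
    p = np.array([[0.9, 0.1, 0.5], [0.2, 0.8, 0.1]])
    y = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
    rows = compute_event_rate_at_topk_causes(p, y, topk_list=[2])
    assert rows[0]["event_rate_mean"] == 0.5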
def compute_random_ranking_baseline_topk(
y_tau: np.ndarray,
topk_list: Sequence[int],
*,
z: float = 1.645,
) -> List[Dict[str, Any]]:
"""Random ranking baseline for Event Rate@K and Recall@K.
Baseline definition:
- For each individual, pick K causes uniformly at random without replacement.
- EventRate@K = (# selected causes that occur) / K.
- Recall@K = (# selected causes that occur) / (# causes that occur), averaged over individuals with >=1 true cause.
This function computes the expected baseline mean and an approximate 5-95% range
for the population mean using a normal approximation of the hypergeometric variance.
Args:
y_tau: (N, K_total) binary labels
topk_list: K values
z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%)
Returns:
Rows with baseline means and p05/p95 for both metrics.
"""
y = np.asarray(y_tau, dtype=float)
if y.ndim != 2:
raise ValueError(
"compute_random_ranking_baseline_topk expects y_tau with shape (N,K)")
n, k_total = y.shape
topks = sorted({int(x) for x in topk_list if int(x) > 0})
if not topks:
return []
g = np.sum(y, axis=1) # (N,)
valid = g > 0
n_valid = int(np.sum(valid))
out: List[Dict[str, Any]] = []
for kk in topks:
kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk)
if n == 0 or k_total == 0 or kk_eff <= 0:
out.append(
{
"topk": int(max(1, kk_eff)),
"baseline_event_rate_mean": float("nan"),
"baseline_event_rate_p05": float("nan"),
"baseline_event_rate_p95": float("nan"),
"baseline_recall_mean": float("nan"),
"baseline_recall_p05": float("nan"),
"baseline_recall_p95": float("nan"),
"n_total": int(n),
"n_valid_recall": int(n_valid),
"k_total": int(k_total),
"baseline_method": "random_ranking_hypergeometric_normal_approx",
}
)
continue
# Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total.
er_mean = float(np.mean(g / float(k_total)))
# Variance of hypergeometric count X:
# Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p=g/K_total.
if k_total > 1 and kk_eff < k_total:
p = g / float(k_total)
finite_corr = (float(k_total - kk_eff) / float(k_total - 1))
var_x = float(kk_eff) * p * (1.0 - p) * finite_corr
else:
var_x = np.zeros_like(g, dtype=float)
var_er = var_x / (float(kk_eff) ** 2)
se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n))
er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0))
er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0))
# Expected Recall@K for individuals with g>0 is K/K_total (clipped).
rec_mean = float(min(float(kk_eff) / float(k_total), 1.0))
if n_valid > 0:
var_rec = np.zeros_like(g, dtype=float)
gv = g[valid]
var_xv = var_x[valid]
# Var( X / g ) = Var(X) / g^2 (approx; g is fixed per individual)
var_rec_v = var_xv / (gv ** 2)
se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid)
rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0))
rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0))
else:
rec_p05 = float("nan")
rec_p95 = float("nan")
out.append(
{
"topk": int(kk_eff),
"baseline_event_rate_mean": er_mean,
"baseline_event_rate_p05": er_p05,
"baseline_event_rate_p95": er_p95,
"baseline_recall_mean": rec_mean,
"baseline_recall_p05": float(rec_p05),
"baseline_recall_p95": float(rec_p95),
"n_total": int(n),
"n_valid_recall": int(n_valid),
"k_total": int(k_total),
"baseline_method": "random_ranking_hypergeometric_normal_approx",
}
)
return out
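# Toy sketch: under random ranking the expected EventRate@K is mean(g / K_total)
# regardless of K, here mean([1/4, 2/4]) = 0.375 for K=2 of 4 causes.
def _example_random_baseline_topk() -> None:
    y = np.array([[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]])
    rows = compute_random_ranking_baseline_topk(y, topk_list=[2])
    assert abs(rows[0]["baseline_event_rate_mean"] - 0.375) < 1e-12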
def count_occurs_within_horizon(
loader: DataLoader,
offset_years: float,
tau_years: float,
n_disease: int,
device: str,
) -> Tuple[np.ndarray, int]:
"""Count per-person occurrence within tau after the prediction context.
Returns counts[k] = number of individuals with disease k at least once in (t_ctx, t_ctx+tau].
"""
counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
n_total_eval = 0
for batch in loader:
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
event_seq = event_seq.to(device)
time_seq = time_seq.to(device)
keep, t_ctx, _ = select_context_indices(
event_seq, time_seq, offset_years)
if not keep.any():
continue
n_total_eval += int(keep.sum().item())
event_seq = event_seq[keep]
time_seq = time_seq[keep]
t_ctx = t_ctx[keep]
B, L = event_seq.shape
b = torch.arange(B, device=device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (float(tau_years) * DAYS_PER_YEAR)
idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
if not in_window.any():
continue
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
# unique per (person, disease) to count per-person within-window occurrence
key = b_idx.to(torch.long) * int(n_disease) + disease_ids
uniq = torch.unique(key)
uniq_disease = uniq % int(n_disease)
counts.scatter_add_(0, uniq_disease, torch.ones_like(
uniq_disease, dtype=torch.long))
return counts.detach().cpu().numpy(), int(n_total_eval)
# ============================================================
# Evaluation core
# ============================================================
def instantiate_model_and_head(
cfg: Dict[str, Any],
dataset: HealthDataset,
device: str,
checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
model_type = str(cfg["model_type"])
loss_type = str(cfg["loss_type"])
loss_params: Dict[str, Any] = {}
if loss_type == "exponential":
out_dims = [dataset.n_disease]
elif loss_type == "discrete_time_cif":
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif loss_type == "lognormal_basis_binned_hazard_cif":
centers = cfg.get("lognormal_centers", None)
if centers is None:
centers = cfg.get("centers", None)
if not isinstance(centers, list) or len(centers) == 0:
raise ValueError(
"lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
)
r = len(centers)
desired_total = int(dataset.n_disease) * int(r)
legacy_total = 1 + desired_total
# Prefer the new shape (K,R) but keep compatibility with older checkpoints
# that used a single flattened dimension (1 + K*R).
out_dims = [int(dataset.n_disease), int(r)]
if checkpoint_path:
try:
ckpt = torch.load(checkpoint_path, map_location="cpu")
head_sd = ckpt.get("head_state_dict", {})
w = head_sd.get("net.2.weight", None)
if isinstance(w, torch.Tensor) and w.ndim == 2:
out_features = int(w.shape[0])
if out_features == legacy_total:
out_dims = [legacy_total]
elif out_features == desired_total:
out_dims = [int(dataset.n_disease), int(r)]
else:
raise ValueError(
f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)"
)
except Exception as e:
raise ValueError(
f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}"
)
loss_params["centers"] = centers
loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0))
else:
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
if model_type == "delphi_fork":
backbone = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
).to(device)
elif model_type == "sap_delphi":
# Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path.
emb_path = cfg.get("pretrained_emb_path", None)
if emb_path in {"", None}:
emb_path = cfg.get("pretrained_emd_path", None)
if emb_path in {"", None}:
run_dir = os.path.dirname(os.path.abspath(
checkpoint_path)) if checkpoint_path else ""
print(
f"WARNING: SapDelphi pretrained embedding path missing in config "
f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
f"checkpoint={checkpoint_path} run_dir={run_dir}"
)
backbone = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
pretrained_weights_path=emb_path,
freeze_embeddings=True,
).to(device)
else:
raise ValueError(f"Unsupported model_type: {model_type}")
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
return backbone, head, loss_type, bin_edges, loss_params
@torch.no_grad()
def predict_cifs_for_model(
backbone: torch.nn.Module,
head: torch.nn.Module,
loss_type: str,
bin_edges: Sequence[float],
loss_params: Dict[str, Any],
loader: DataLoader,
device: str,
offset_years: float,
eval_horizons: Sequence[float],
n_disease: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Run model and produce cause-specific, time-dependent CIF outputs.
Returns:
cif_full: (N, K, H)
survival: (N, H)
y_cause_within_tau: (N, K, H)
sex: (N,) sex code for each kept sample
NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
"""
backbone.eval()
head.eval()
# We will accumulate in CPU lists, then concat.
cif_full_list: List[np.ndarray] = []
survival_list: List[np.ndarray] = []
y_cause_within_list: List[np.ndarray] = []
sex_list: List[np.ndarray] = []
for batch in loader:
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
event_seq = event_seq.to(device)
time_seq = time_seq.to(device)
cont_feats = cont_feats.to(device)
cate_feats = cate_feats.to(device)
sexes = sexes.to(device)
keep, t_ctx, _ = select_context_indices(
event_seq, time_seq, offset_years)
if not keep.any():
continue
# filter batch
event_seq = event_seq[keep]
time_seq = time_seq[keep]
cont_feats = cont_feats[keep]
cate_feats = cate_feats[keep]
sexes_k = sexes[keep]
t_ctx = t_ctx[keep]
h = backbone(event_seq, time_seq, sexes_k,
cont_feats, cate_feats) # (B,L,D)
b = torch.arange(h.size(0), device=device)
c = h[b, t_ctx] # (B,D)
logits = head(c)
if loss_type == "exponential":
cif_full, survival = cifs_from_exponential_logits(
logits, eval_horizons, return_survival=True) # (B,K,H), (B,H)
elif loss_type == "discrete_time_cif":
cif_full, survival = cifs_from_discrete_time_logits(
# (B,K,H), (B,H)
logits, bin_edges, eval_horizons, return_survival=True)
elif loss_type == "lognormal_basis_binned_hazard_cif":
centers = loss_params.get("centers", None)
sigma = loss_params.get("sigma", None)
if centers is None or sigma is None:
raise ValueError(
"lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']")
cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits(
logits,
centers=centers,
sigma=sigma,
bin_edges=bin_edges,
taus=eval_horizons,
eps=float(loss_params.get("loss_eps", 1e-8)),
alpha_floor=float(loss_params.get("alpha_floor", 0.0)),
return_survival=True,
)
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")
# Within-horizon labels for all causes: disease k occurs within tau after context.
y_within_full = torch.stack(
[
multi_hot_ever_within_horizon(
event_seq=event_seq,
time_seq=time_seq,
t_ctx=t_ctx,
tau_years=float(tau),
n_disease=int(n_disease),
).to(torch.float32)
for tau in eval_horizons
],
dim=2,
) # (B,K,H)
cif_full_list.append(cif_full.detach().cpu().numpy())
survival_list.append(survival.detach().cpu().numpy())
y_cause_within_list.append(y_within_full.detach().cpu().numpy())
sex_list.append(sexes_k.detach().cpu().numpy())
if not cif_full_list:
raise RuntimeError(
"No valid samples for evaluation (all batches filtered out by offset).")
cif_full = np.concatenate(cif_full_list, axis=0)
survival = np.concatenate(survival_list, axis=0)
y_cause_within = np.concatenate(y_cause_within_list, axis=0)
sex = np.concatenate(
sex_list, axis=0) if sex_list else np.array([], dtype=int)
return cif_full, survival, y_cause_within, sex
def evaluate_one_model(
model_name: str,
cif_full: np.ndarray,
y_cause_within_tau: np.ndarray,
eval_horizons: Sequence[float],
out_rows: List[Dict[str, Any]],
calib_rows: List[Dict[str, Any]],
calib_cause_ids: Optional[Sequence[int]],
n_calib_bins: int = 10,
metric_workers: int = 0,
progress: str = "auto",
) -> None:
"""Compute per-cause metrics for ALL diseases.
Notes:
- Writes scalar metrics for all causes into out_rows.
- Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable).
"""
cif_full = np.asarray(cif_full, dtype=float)
y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float)
if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3:
raise ValueError(
"Expected cif_full and y_cause_within_tau with shape (N, K, H)")
if cif_full.shape != y_cause_within_tau.shape:
raise ValueError(
f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}"
)
N, K, H = cif_full.shape
if H != len(eval_horizons):
raise ValueError("H mismatch between cif_full and eval_horizons")
calib_set = set(int(x)
for x in calib_cause_ids) if calib_cause_ids is not None else set()
workers = int(metric_workers)
if workers <= 0:
workers = int(min(8, os.cpu_count() or 1))
workers = max(1, workers)
show_progress = _should_show_progress(progress)
def _eval_chunk(
*,
tau: float,
p_tau: np.ndarray,
y_tau: np.ndarray,
brier_by_cause: np.ndarray,
cause_ids: np.ndarray,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]:
local_rows: List[Dict[str, Any]] = []
local_calib: List[Dict[str, Any]] = []
for cid in cause_ids.tolist():
p = p_tau[:, cid]
y = y_tau[:, cid]
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_brier",
"horizon": float(tau),
"cause": int(cid),
"value": float(brier_by_cause[cid]),
"ci_low": "",
"ci_high": "",
}
)
# ICI: compute bins only if we will export them.
need_bins = (not calib_set) or (int(cid) in calib_set)
if need_bins:
cal = calibration_deciles(p, y, n_bins=n_calib_bins)
ici = float(cal["ici"])
else:
cal = None
ici = calibration_ici_only(p, y, n_bins=n_calib_bins)
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_ici",
"horizon": float(tau),
"cause": int(cid),
"value": float(ici),
"ci_low": "",
"ci_high": "",
}
)
# Secondary: discrimination via AUC at the same horizon (point estimate only).
auc = roc_auc_rank(y, p)
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_auc",
"horizon": float(tau),
"cause": int(cid),
"value": float(auc),
"ci_low": "",
"ci_high": "",
}
)
if need_bins and cal is not None:
for binfo in cal.get("bins", []):
local_calib.append(
{
"model_name": model_name,
"task": "cause_k",
"horizon": float(tau),
"cause_id": int(cid),
"bin_index": int(binfo["bin"]),
"p_mean": float(binfo["p_mean"]),
"y_mean": float(binfo["y_mean"]),
"n_in_bin": int(binfo["n"]),
}
)
return local_rows, local_calib, int(cause_ids.size)
# Cause-specific, time-dependent metrics per horizon.
for h_i, tau in enumerate(eval_horizons):
p_tau = cif_full[:, :, h_i] # (N, K)
y_tau = y_cause_within_tau[:, :, h_i] # (N, K)
# Vectorized Brier for speed.
brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0) # (K,)
# Parallelize disease-level metrics; chunk to avoid millions of futures.
all_ids = np.arange(int(K), dtype=int)
chunks = np.array_split(all_ids, workers)
done = 0
prefix = f"[{model_name}] tau={float(tau)}y "
t0 = time.time()
if workers <= 1:
for ch in chunks:
r_chunk, c_chunk, n_done = _eval_chunk(
tau=float(tau),
p_tau=p_tau,
y_tau=y_tau,
brier_by_cause=brier_by_cause,
cause_ids=ch,
)
out_rows.extend(r_chunk)
calib_rows.extend(c_chunk)
done += int(n_done)
if show_progress:
sys.stdout.write(
"\r" + _progress_line(done, int(K), prefix=prefix))
sys.stdout.flush()
else:
with ThreadPoolExecutor(max_workers=workers) as ex:
futs = [
ex.submit(
_eval_chunk,
tau=float(tau),
p_tau=p_tau,
y_tau=y_tau,
brier_by_cause=brier_by_cause,
cause_ids=ch,
)
for ch in chunks
if int(ch.size) > 0
]
for fut in as_completed(futs):
r_chunk, c_chunk, n_done = fut.result()
out_rows.extend(r_chunk)
calib_rows.extend(c_chunk)
done += int(n_done)
if show_progress:
sys.stdout.write(
"\r" + _progress_line(done, int(K), prefix=prefix))
sys.stdout.flush()
if show_progress:
dt = time.time() - t0
sys.stdout.write("\r" + _progress_line(int(K),
int(K), prefix=prefix) + f" ({dt:.1f}s)\n")
sys.stdout.flush()
def summarize_over_diseases(
rows: List[Dict[str, Any]],
*,
model_name: str,
eval_horizons: Sequence[float],
metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"),
) -> List[Dict[str, Any]]:
"""Summarize mean/median of each metric over diseases (per horizon)."""
out: List[Dict[str, Any]] = []
# Build metric_name -> horizon -> list of values
bucket: Dict[Tuple[str, float], List[float]] = {}
for r in rows:
if r.get("model_name") != model_name:
continue
m = str(r.get("metric_name"))
if m not in set(metrics):
continue
h = _safe_float(r.get("horizon"))
v = _safe_float(r.get("value"))
if not np.isfinite(h):
continue
if not np.isfinite(v):
continue
bucket.setdefault((m, float(h)), []).append(float(v))
for tau in eval_horizons:
ht = float(tau)
for m in metrics:
vals = bucket.get((str(m), ht), [])
if vals:
arr = np.asarray(vals, dtype=float)
mean_v = float(np.mean(arr))
med_v = float(np.median(arr))
n_valid = int(arr.size)
else:
mean_v = float("nan")
med_v = float("nan")
n_valid = 0
out.append(
{
"model_name": str(model_name),
"metric_name": str(m),
"horizon": ht,
"mean": mean_v,
"median": med_v,
"n_valid": n_valid,
}
)
return out
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
fieldnames = [
"model_name",
"task",
"horizon",
"cause_id",
"bin_index",
"p_mean",
"y_mean",
"n_in_bin",
]
with open(path, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=fieldnames)
w.writeheader()
for r in rows:
w.writerow(r)
def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
fieldnames = [
"model_name",
"metric_name",
"horizon",
"cause",
"value",
"ci_low",
"ci_high",
]
with open(path, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=fieldnames)
w.writeheader()
for r in rows:
w.writerow(r)
def _make_eval_tag(split: str, offset_years: float) -> str:
"""Short tag for filenames written into run directories."""
off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
return f"{split}_offset{off}y"
def main() -> int:
ap = argparse.ArgumentParser(
description="Unified downstream evaluation via CIFs")
ap.add_argument("--models_json", type=str, required=True,
help="Path to models list JSON")
ap.add_argument("--data_prefix", type=str,
default="ukb", help="Dataset prefix")
ap.add_argument("--split", type=str, default="test",
choices=["train", "val", "test", "all"], help="Which split to evaluate")
ap.add_argument("--offset_years", type=float, default=0.5,
help="Anti-leakage offset (years)")
ap.add_argument("--eval_horizons", type=float,
nargs="*", default=DEFAULT_EVAL_HORIZONS)
ap.add_argument("--batch_size", type=int, default=128)
ap.add_argument("--num_workers", type=int, default=0)
ap.add_argument("--seed", type=int, default=123)
ap.add_argument("--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
ap.add_argument("--out_csv", type=str, default="eval_summary.csv")
ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
# Integrity checks
ap.add_argument("--integrity_strict", action="store_true", default=False)
ap.add_argument("--integrity_tol", type=float, default=1e-6)
# Speed/UX
ap.add_argument(
"--metric_workers",
type=int,
default=0,
help="Threads for per-disease metrics (0=auto, 1=disable parallelism)",
)
ap.add_argument(
"--progress",
type=str,
default="auto",
choices=["auto", "bar", "none"],
help="Progress visualization during per-disease evaluation",
)
# Export settings for user-facing experiments
ap.add_argument("--export_dir", type=str, default="eval_exports")
ap.add_argument("--death_cause_id", type=int,
default=DEFAULT_DEATH_CAUSE_ID)
ap.add_argument("--focus_k", type=int, default=5,
help="Additional non-death causes to include")
ap.add_argument("--capture_k_pcts", type=int,
nargs="*", default=[1, 5, 10, 20])
ap.add_argument(
"--capture_curve_max_pct",
type=int,
default=50,
help="If >0, also export a dense capture curve for k=1..max_pct",
)
# High-risk cause concentration (cross-cause prioritization)
ap.add_argument(
"--cause_concentration_topk",
type=int,
nargs="*",
default=[5, 10, 20, 50],
help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
)
ap.add_argument(
"--cause_concentration_write_random_baseline",
action="store_true",
default=False,
help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)",
)
args = ap.parse_args()
set_deterministic(args.seed)
specs = load_models_json(args.models_json)
if not specs:
raise ValueError("No models provided")
export_dir = str(args.export_dir)
_ensure_dir(export_dir)
cause_names = load_cause_names("labels.csv")
# Determine top-K causes from the evaluation split only (model-agnostic),
# aligned to time-dependent risk: occurrence within tau_max after context.
first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [
"bmi", "smoking", "alcohol"]
dataset_for_top = HealthDataset(
data_prefix=args.data_prefix, covariate_list=cov_list)
subset_for_top = build_eval_subset(
dataset_for_top,
train_ratio=float(first_cfg.get("train_ratio", 0.7)),
val_ratio=float(first_cfg.get("val_ratio", 0.15)),
seed=int(first_cfg.get("random_seed", 42)),
split=args.split,
)
loader_top = DataLoader(
subset_for_top,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=health_collate_fn,
)
tau_max = float(max(args.eval_horizons))
counts, n_total_eval = count_occurs_within_horizon(
loader=loader_top,
offset_years=args.offset_years,
tau_years=tau_max,
n_disease=dataset_for_top.n_disease,
device=args.device,
)
focus_causes = pick_focus_causes(
counts_within_tau=counts,
n_disease=int(dataset_for_top.n_disease),
death_cause_id=int(args.death_cause_id),
k=int(args.focus_k),
)
top_cause_ids = np.asarray(focus_causes, dtype=int)
# Export the chosen focus causes.
focus_rows: List[Dict[str, Any]] = []
for r, cid in enumerate(focus_causes, start=1):
row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)}
if cid in cause_names:
row["cause_name"] = cause_names[cid]
focus_rows.append(row)
focus_fieldnames = ["cause", "rank"] + \
(["cause_name"] if any("cause_name" in r for r in focus_rows) else [])
write_simple_csv(os.path.join(export_dir, "focus_causes.csv"),
focus_fieldnames, focus_rows)
# Metadata for focus causes (within tau_max).
top_causes_meta: List[Dict[str, Any]] = []
for cid in focus_causes:
n_case = int(counts[int(cid)]) if int(
cid) < int(counts.shape[0]) else 0
top_causes_meta.append(
{
"cause_id": int(cid),
"tau_years": float(tau_max),
"n_case_within_tau": n_case,
"n_control_within_tau": int(n_total_eval - n_case),
"n_total_eval": int(n_total_eval),
}
)
summary_rows: List[Dict[str, Any]] = []
calib_rows: List[Dict[str, Any]] = []
# Experiment exports (accumulated across models)
cap_points_rows: List[Dict[str, Any]] = []
cap_curve_rows: List[Dict[str, Any]] = []
conc_rows: List[Dict[str, Any]] = []
conc_base_rows: List[Dict[str, Any]] = []
# Track per-model integrity status for meta JSON.
integrity_meta: Dict[str, Any] = {}
# Evaluate each model
for spec in specs:
run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))
tag = _make_eval_tag(args.split, float(args.offset_years))
# Remember list offsets so we can write per-model slices to the model's run_dir.
calib_start = len(calib_rows)
cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
# Identifiers for consistent exports
model_id = str(spec.name)
        model_type = str(cfg.get("model_type", spec.model_type))
        loss_type_id = str(cfg.get("loss_type", spec.loss_type))
age_encoder = str(cfg.get("age_encoder", ""))
cov_type = "full" if _parse_bool(
cfg.get("full_cov", False)) else "partial"
cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [
"bmi", "smoking", "alcohol"]
dataset = HealthDataset(
data_prefix=args.data_prefix, covariate_list=cov_list)
subset = build_eval_subset(
dataset,
train_ratio=float(cfg.get("train_ratio", 0.7)),
val_ratio=float(cfg.get("val_ratio", 0.15)),
seed=int(cfg.get("random_seed", 42)),
split=args.split,
)
loader = DataLoader(
subset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=health_collate_fn,
)
ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
head.load_state_dict(ckpt["head_state_dict"], strict=True)
if loss_type == "lognormal_basis_binned_hazard_cif":
crit_state = ckpt.get("criterion_state_dict", {})
log_sigma = crit_state.get("log_sigma", None)
if isinstance(log_sigma, torch.Tensor):
log_sigma_t = log_sigma.to(device=args.device)
sigma = torch.exp(log_sigma_t)
            else:
                bandwidth_init = float(loss_params.get("bandwidth_init", 0.7))
                sigma = torch.tensor(bandwidth_init, device=args.device)
bmin = float(loss_params.get("bandwidth_min", 1e-3))
bmax = float(loss_params.get("bandwidth_max", 10.0))
sigma = torch.clamp(sigma, min=bmin, max=bmax)
loss_params["sigma"] = sigma
(
cif_full,
survival,
y_cause_within_tau,
sex,
) = predict_cifs_for_model(
backbone,
head,
loss_type,
bin_edges,
loss_params,
loader,
args.device,
args.offset_years,
args.eval_horizons,
n_disease=int(dataset.n_disease),
)
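        # Assumed layout (consistent with the indexing below): cif_full and
        # y_cause_within_tau are [n_individuals, n_disease, n_horizons] arrays;
        # survival feeds only the integrity check; sex provides the per-individual
        # labels used for the sex-stratified slices.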
# CIF integrity checks before metrics.
integrity_ok, integrity_notes = check_cif_integrity(
cif_full,
args.eval_horizons,
tol=float(args.integrity_tol),
name=spec.name,
strict=bool(args.integrity_strict),
survival=survival,
)
integrity_meta[spec.name] = {
"integrity_ok": bool(integrity_ok),
"integrity_notes": integrity_notes,
}
# Per-disease metrics for ALL diseases (written into the model's run_dir).
model_rows: List[Dict[str, Any]] = []
evaluate_one_model(
model_name=spec.name,
cif_full=cif_full,
y_cause_within_tau=y_cause_within_tau,
eval_horizons=args.eval_horizons,
out_rows=model_rows,
calib_rows=calib_rows,
calib_cause_ids=top_cause_ids.tolist(),
metric_workers=int(args.metric_workers),
progress=str(args.progress),
)
# Summary over diseases (mean/median per horizon).
model_summary_rows = summarize_over_diseases(
model_rows,
model_name=spec.name,
eval_horizons=args.eval_horizons,
)
summary_rows.extend(model_summary_rows)
# ============================================================
# Experiment: High-Risk Cause Concentration at fixed horizon
# (cross-cause prioritization accuracy)
# ============================================================
topk_causes = [int(x) for x in args.cause_concentration_topk]
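        # Cross-cause prioritization: rank ALL causes per individual by predicted CIF
        # at each horizon and keep the top K. Event Rate@K is (roughly) the fraction
        # of the selected causes that actually occur within the horizon; Recall@K is
        # the fraction of an individual's occurring causes captured in the top K.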
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for h_i, tau in enumerate(args.eval_horizons):
p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float)
y_tau_all = np.asarray(
y_cause_within_tau[:, :, h_i], dtype=float)
if sex_mask is not None:
p_tau_all = p_tau_all[sex_mask]
y_tau_all = y_tau_all[sex_mask]
for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes):
conc_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"horizon": float(tau),
"sex": sex_label,
**rr,
}
)
if bool(args.cause_concentration_write_random_baseline):
for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes):
conc_base_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"horizon": float(tau),
"sex": sex_label,
**rr,
}
)
# Convenience slices for user-facing experiments (focus causes only).
cause_cif_focus = cif_full[:, top_cause_ids, :]
y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
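        # Focus-cause views: [n_individuals, len(top_cause_ids), n_horizons] slices of
        # the full CIF and label arrays, ordered like top_cause_ids.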
# ============================================================
# Experiment 2: High-risk capture points (+ optional curve)
# ============================================================
k_pcts = [int(x) for x in args.capture_k_pcts]
curve_max = int(args.capture_curve_max_pct)
        curve_grid = list(range(1, curve_max + 1)) if curve_max > 0 else []
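        # compute_capture_points targets the top k% of individuals by predicted risk
        # for each (cause, horizon, sex) slice and reports how many observed events
        # fall in the targeted group (event_capture_rate) and the event rate among the
        # targeted (precision_in_targeted); curve_grid densifies this to k = 1..N%.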
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for h_i, tau in enumerate(args.eval_horizons):
for j, cause_id in enumerate(top_cause_ids.tolist()):
p = cause_cif_focus[:, j, h_i]
y = y_within_focus[:, j, h_i]
if sex_mask is not None:
p = p[sex_mask]
y = y[sex_mask]
for r in compute_capture_points(p, y, k_pcts):
cap_points_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"horizon": float(tau),
"sex": sex_label,
**r,
}
)
if curve_grid:
for r in compute_capture_points(p, y, curve_grid):
cap_curve_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"horizon": float(tau),
"sex": sex_label,
**r,
}
)
        # Record top-cause support counts in the per-model results CSV as metric rows.
for tc in top_causes_meta:
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_case_within_tau",
"horizon": float(tc["tau_years"]),
"cause": int(tc["cause_id"]),
"value": int(tc["n_case_within_tau"]),
"ci_low": "",
"ci_high": "",
}
)
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_control_within_tau",
"horizon": float(tc["tau_years"]),
"cause": int(tc["cause_id"]),
"value": int(tc["n_control_within_tau"]),
"ci_low": "",
"ci_high": "",
}
)
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_total_eval",
"horizon": float(tc["tau_years"]),
"cause": int(tc["cause_id"]),
"value": int(tc["n_total_eval"]),
"ci_low": "",
"ci_high": "",
}
)
# Write per-model results into the model's run directory.
model_calib_rows = calib_rows[calib_start:]
model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")
write_results_csv(model_out_csv, model_rows)
write_simple_csv(
model_summary_csv,
["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
model_summary_rows,
)
write_calibration_bins_csv(model_calib_csv, model_calib_rows)
model_meta = {
"model_name": spec.name,
"checkpoint_path": spec.checkpoint_path,
"run_dir": run_dir,
"split": args.split,
"offset_years": args.offset_years,
"eval_horizons": [float(x) for x in args.eval_horizons],
"n_disease": int(dataset.n_disease),
"top_cause_ids": top_cause_ids.tolist(),
"top_causes": top_causes_meta,
"integrity": {spec.name: integrity_meta.get(spec.name, {})},
"paths": {
"results_csv": model_out_csv,
"summary_csv": model_summary_csv,
"calibration_bins_csv": model_calib_csv,
},
}
with open(model_meta_json, "w") as f:
json.dump(model_meta, f, indent=2)
print(f"Wrote per-model results to {model_out_csv}")
    # Write the global summary (aggregated over diseases) covering all models.
write_simple_csv(
args.out_csv,
["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
summary_rows,
)
    # Write calibration curve points (accumulated across all models) to a separate CSV.
out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
write_calibration_bins_csv(calib_csv_path, calib_rows)
# Write experiment exports
write_simple_csv(
os.path.join(export_dir, "lift_capture_points.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"horizon",
"sex",
"k_pct",
"n_targeted",
"events_targeted",
"events_total",
"event_capture_rate",
"precision_in_targeted",
],
cap_points_rows,
)
if cap_curve_rows:
write_simple_csv(
os.path.join(export_dir, "lift_capture_curve.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"horizon",
"sex",
"k_pct",
"n_targeted",
"events_targeted",
"events_total",
"event_capture_rate",
"precision_in_targeted",
],
cap_curve_rows,
)
write_simple_csv(
os.path.join(export_dir, "high_risk_cause_concentration.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"horizon",
"sex",
"topk",
"event_rate_mean",
"event_rate_median",
"recall_mean",
"recall_median",
"n_total",
"n_valid_recall",
],
conc_rows,
)
if conc_base_rows:
write_simple_csv(
os.path.join(
export_dir, "high_risk_cause_concentration_random_baseline.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"horizon",
"sex",
"topk",
"baseline_event_rate_mean",
"baseline_event_rate_p05",
"baseline_event_rate_p95",
"baseline_recall_mean",
"baseline_recall_p05",
"baseline_recall_p95",
"n_total",
"n_valid_recall",
"k_total",
"baseline_method",
],
conc_base_rows,
)
# Manifest markdown (stable, user-facing)
manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
with open(manifest_path, "w", encoding="utf-8") as f:
f.write(
"# Evaluation Exports Manifest\n\n"
"This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
"All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
"## Files\n\n"
"- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
"- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
"- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
"- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart vs K.\n"
"- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n"
)
meta = {
"split": args.split,
"offset_years": args.offset_years,
"eval_horizons": [float(x) for x in args.eval_horizons],
"tau_max": float(tau_max),
"n_disease": int(dataset_for_top.n_disease),
"top_cause_ids": top_cause_ids.tolist(),
"top_causes": top_causes_meta,
"integrity": integrity_meta,
"notes": {
"label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
"primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
"secondary_metrics": "cause_auc (discrimination)",
"exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
"warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
"exports_dir": export_dir,
"focus_causes": focus_causes,
},
}
with open(args.out_meta_json, "w") as f:
json.dump(meta, f, indent=2)
print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
print(f"Wrote {args.out_meta_json}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())