# DeepHealth/evaluate_models.py
import argparse
import csv
import json
import math
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi, SimpleHead
# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
DAYS_PER_YEAR = 365.25
DEFAULT_DEATH_CAUSE_ID = 1256
# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
name: str
model_type: str # delphi_fork | sap_delphi
loss_type: str # exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif
full_cov: bool
checkpoint_path: str
# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
if isinstance(x, bool):
return x
s = str(x).strip().lower()
if s in {"true", "1", "yes", "y"}:
return True
if s in {"false", "0", "no", "n"}:
return False
raise ValueError(f"Cannot parse boolean: {x!r}")
def load_models_json(path: str) -> List[ModelSpec]:
with open(path, "r") as f:
data = json.load(f)
if not isinstance(data, list):
raise ValueError("models_json must be a list of model entries")
specs: List[ModelSpec] = []
for row in data:
specs.append(
ModelSpec(
name=str(row["name"]),
model_type=str(row["model_type"]),
loss_type=str(row["loss_type"]),
full_cov=_parse_bool(row["full_cov"]),
checkpoint_path=str(row["checkpoint_path"]),
)
)
return specs
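# Illustrative sketch of the models_json format consumed by load_models_json above.
# The name and checkpoint path below are hypothetical placeholders, not shipped files;
# the keys mirror the ModelSpec fields and the allowed model_type/loss_type values.
#
#   [
#     {
#       "name": "delphi_fork_exponential_partial",
#       "model_type": "delphi_fork",
#       "loss_type": "exponential",
#       "full_cov": false,
#       "checkpoint_path": "runs/delphi_fork_exp/best.pt"
#     }
#   ]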
def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
cfg_path = os.path.join(run_dir, "train_config.json")
with open(cfg_path, "r") as f:
cfg = json.load(f)
return cfg
def build_eval_subset(
dataset: HealthDataset,
train_ratio: float,
val_ratio: float,
seed: int,
split: str,
):
n_total = len(dataset)
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
train_ds, val_ds, test_ds = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
if split == "train":
return train_ds
if split == "val":
return val_ds
if split == "test":
return test_ds
if split == "all":
return dataset
raise ValueError("split must be one of: train, val, test, all")
# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Select per-sample prediction context index.
IMPORTANT SEMANTICS:
- The last observed token time is treated as the FOLLOW-UP END time.
- We pick the last valid token with time <= (followup_end_time - offset).
- We do NOT interpret followup_end_time as an event time.
Returns:
keep_mask: (B,) bool, which samples have a valid context
t_ctx: (B,) long, index into sequence
t_ctx_time: (B,) float, time (days) at context
"""
# valid tokens are event != 0 (padding is 0)
valid = event_seq != 0
lengths = valid.sum(dim=1)
last_idx = torch.clamp(lengths - 1, min=0)
b = torch.arange(event_seq.size(0), device=event_seq.device)
followup_end_time = time_seq[b, last_idx]
t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)
eligible = valid & (time_seq <= t_cut.unsqueeze(1))
eligible_counts = eligible.sum(dim=1)
keep = eligible_counts > 0
t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
t_ctx_time = time_seq[b, t_ctx]
return keep, t_ctx, t_ctx_time
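# Worked sketch (hedged, not executed) of the context selection above.
# Toy sequence with valid tokens at days [0, 1000, 2000, 2200]; the last time (2200)
# is treated as the follow-up end. With offset_years=0.5 (~182.6 days),
# t_cut = 2200 - 182.6 = 2017.4, so the last eligible token is the one at day 2000
# (index 2), giving t_ctx=2 and t_ctx_time=2000. A sample with no token at or before
# t_cut gets keep_mask=False and is dropped from evaluation.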
def next_event_after_context(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return next disease event after context.
Returns:
dt_years: (B,) float, time to next disease in years; +inf if none
cause: (B,) long, disease id in [0,K) for next event; -1 if none
"""
B, L = event_seq.shape
b = torch.arange(B, device=event_seq.device)
t0 = time_seq[b, t_ctx]
# Allow same-day events while excluding the context token itself.
# We rely on time-sorted sequences and select the FIRST valid future event by index.
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
idx_min = torch.where(
future, idxs, torch.full_like(idxs, L)).min(dim=1).values
has = idx_min < L
t_next = torch.where(has, idx_min, torch.zeros_like(idx_min))
t_next_time = time_seq[b, t_next]
dt_days = t_next_time - t0
dt_years = dt_days / DAYS_PER_YEAR
dt_years = torch.where(
has, dt_years, torch.full_like(dt_years, float("inf")))
cause_token = event_seq[b, t_next]
cause = (cause_token - 2).to(torch.long)
cause = torch.where(has, cause, torch.full_like(cause, -1))
return dt_years, cause
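# Continuing the toy sequence above: with t_ctx=2 (day 2000), the first later-index
# token whose event id is >= 2 is index 3 (day 2200), so dt_years = 200 / 365.25
# (about 0.55) and cause = event_id - 2. If no such token exists, dt_years = +inf
# and cause = -1.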
def multi_hot_ever_within_horizon(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
tau_years: float,
n_disease: int,
) -> torch.Tensor:
"""Binary labels: disease k occurs within tau after context (any occurrence)."""
B, L = event_seq.shape
b = torch.arange(B, device=event_seq.device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (tau_years * DAYS_PER_YEAR)
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
# Include same-day events after context, exclude any token at/before context index.
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
if not in_window.any():
return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
y[b_idx, disease_ids] = True
return y
def multi_hot_ever_after_context_anytime(
event_seq: torch.Tensor,
t_ctx: torch.Tensor,
n_disease: int,
) -> torch.Tensor:
"""Binary labels: disease k occurs ANYTIME after the prediction context.
This is Delphi2M-compatible for Task A case/control definition.
Same-day events are included as long as they occur after the context token index.
"""
B, L = event_seq.shape
idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
if not future.any():
return y
b_idx, t_idx = future.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
y[b_idx, disease_ids] = True
return y
def multi_hot_selected_causes_within_horizon(
event_seq: torch.Tensor,
time_seq: torch.Tensor,
t_ctx: torch.Tensor,
tau_years: float,
cause_ids: torch.Tensor,
n_disease: int,
) -> torch.Tensor:
"""Labels for selected causes only: does cause k occur within tau after context?"""
B, L = event_seq.shape
device = event_seq.device
b = torch.arange(B, device=device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (tau_years * DAYS_PER_YEAR)
idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
if not in_window.any():
return out
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
# Filter to selected causes via a boolean membership mask over the global disease space.
selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
selected[cause_ids] = True
keep = selected[disease_ids]
if not keep.any():
return out
b_idx = b_idx[keep]
disease_ids = disease_ids[keep]
# Map disease_id -> local index in cause_ids
# Build a lookup table (global disease space) where lookup[disease_id] = local_index
lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
local = lookup[disease_ids]
out[b_idx, local] = True
return out
# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
logits: torch.Tensor,
taus: Sequence[float],
eps: float = 1e-6,
return_survival: bool = False,
) -> torch.Tensor:
"""Convert exponential cause-specific logits -> CIFs at taus.
logits: (B, K)
returns: (B, K, H) or (cif, survival) if return_survival
"""
hazards = F.softplus(logits) + eps
total = hazards.sum(dim=1, keepdim=True) # (B,1)
taus_t = torch.tensor(list(taus), device=logits.device,
dtype=hazards.dtype).view(1, 1, -1)
total_h = total.unsqueeze(-1) # (B,1,1)
# (1 - exp(-Lambda * tau))
one_minus_surv = 1.0 - torch.exp(-total_h * taus_t)
frac = hazards / torch.clamp(total, min=eps)
cif = frac.unsqueeze(-1) * one_minus_surv # (B,K,H)
# If total==0, set to 0
cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))
if not return_survival:
return cif
survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H)
# Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
nonzero = (total.squeeze(1) > 0).unsqueeze(1)
survival = torch.where(nonzero, survival, torch.ones_like(survival))
return cif, survival
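# Closed form used above (worked sketch): with per-cause hazards h_k = softplus(logit_k)
# and total hazard H = sum_k h_k, the cause-specific CIF at horizon tau is
#   CIF_k(tau) = (h_k / H) * (1 - exp(-H * tau)),   S(tau) = exp(-H * tau).
# Quick shape check with arbitrary toy logits (hedged example, not part of the pipeline):
#   cif, surv = cifs_from_exponential_logits(torch.zeros(2, 3), taus=[1.0, 5.0],
#                                            return_survival=True)
#   cif.shape -> torch.Size([2, 3, 2]); surv.shape -> torch.Size([2, 2])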
def cifs_from_discrete_time_logits(
logits: torch.Tensor,
bin_edges: Sequence[float],
taus: Sequence[float],
return_survival: bool = False,
) -> torch.Tensor:
"""Convert discrete-time CIF logits -> CIFs at taus.
logits: (B, K+1, n_bins+1)
bin_edges: len=n_bins+1 (including 0 and inf)
taus: subset of finite bin edges (recommended)
returns: (B, K, H) or (cif, survival) if return_survival
"""
if logits.ndim != 3:
raise ValueError("Expected logits shape (B, K+1, n_bins+1)")
B, K_plus_1, n_bins_plus_1 = logits.shape
K = K_plus_1 - 1
edges = [float(x) for x in bin_edges]
# drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf
finite_edges = [e for e in edges[1:] if math.isfinite(e)]
n_bins = len(finite_edges)
if n_bins_plus_1 != len(edges):
raise ValueError("logits last dim must match len(bin_edges)")
probs = torch.softmax(logits, dim=1) # (B, K+1, n_bins+1)
# use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot)
hazards = probs[:, :K, 1: 1 + n_bins] # (B,K,n_bins)
p_comp = probs[:, K, 1: 1 + n_bins] # (B,n_bins)
# survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v< u} p_comp[v]
ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
cum = torch.cumprod(p_comp, dim=1)
s_prev = torch.cat([ones, cum[:, :-1]], dim=1) # (B, n_bins)
cif_bins = torch.cumsum(s_prev.unsqueeze(
1) * hazards, dim=2) # (B,K,n_bins)
# Robust mapping from tau -> edge index (floating-point safe).
# taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
finite_edges_arr = np.asarray(finite_edges, dtype=float)
tau_to_idx: List[int] = []
for tau in taus:
tau_f = float(tau)
if not math.isfinite(tau_f):
raise ValueError("taus must be finite for discrete-time CIF")
diffs = np.abs(finite_edges_arr - tau_f)
j = int(np.argmin(diffs))
if diffs[j] > 1e-6:
raise ValueError(
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
)
tau_to_idx.append(j)
idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H)
if not return_survival:
return cif
# Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
survival = survival_bins.index_select(dim=1, index=idx) # (B,H)
return cif, survival
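# Worked sketch of the discrete-time recursion above: with per-bin event probabilities
# hazards[:, k, u] and "no event in bin u" probability p_comp[:, u],
#   s_prev[u] = prod_{v < u} p_comp[v]              (survival entering bin u)
#   CIF_k(edge_u) = sum_{v <= u} s_prev[v] * hazards[k, v]
#   S(edge_u) = prod_{v <= u} p_comp[v]
# Horizons in `taus` must coincide (within 1e-6) with finite bin edges; e.g. every
# value in DEFAULT_EVAL_HORIZONS appears among the finite DEFAULT_BIN_EDGES.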
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
z = torch.clamp(z, -12.0, 12.0)
return 0.5 * (1.0 + torch.erf(z / math.sqrt(2.0)))
def cifs_from_lognormal_basis_binned_hazard_logits(
logits: torch.Tensor,
*,
centers: Sequence[float],
sigma: torch.Tensor,
bin_edges: Sequence[float],
taus: Sequence[float],
eps: float = 1e-8,
alpha_floor: float = 0.0,
return_survival: bool = False,
) -> torch.Tensor:
"""Convert Route-3 binned hazard logits -> CIFs at taus.
logits: (B, J, R) OR (B, J*R) OR (B, 1+J*R) (leading column ignored).
taus are expected to align with finite bin edges.
"""
if logits.ndim not in {2, 3}:
raise ValueError("logits must be 2D or 3D")
if sigma.ndim != 0:
raise ValueError("sigma must be a scalar tensor")
device = logits.device
dtype = logits.dtype
centers_t = torch.tensor([float(x)
for x in centers], device=device, dtype=dtype)
r = int(centers_t.numel())
if r <= 0:
raise ValueError("centers must be non-empty")
offset = 0
if logits.ndim == 3:
j = int(logits.shape[1])
if int(logits.shape[2]) != r:
raise ValueError(
f"logits.shape[2] must equal R={r}; got {int(logits.shape[2])}"
)
else:
d = int(logits.shape[1])
if d % r == 0:
jr = d
elif (d - 1) % r == 0:
offset = 1
jr = d - 1
else:
raise ValueError(
f"logits.shape[1] must be divisible by R={r} (or 1+J*R); got {d}")
j = jr // r
if j <= 0:
raise ValueError("Inferred J must be >= 1")
edges = [float(x) for x in bin_edges]
finite_edges = [e for e in edges[1:] if math.isfinite(e)]
n_bins = len(finite_edges)
if n_bins <= 0:
raise ValueError("bin_edges must contain at least one finite edge")
# Build finite bins [edges[k-1], edges[k]) for k=1..n_bins
left = torch.tensor(edges[:n_bins], device=device, dtype=dtype)
right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype)
# Stable t_min clamp (aligns with training loss rule).
t_min = 1e-12
if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
t_min = edges[1] * 1e-6
t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)
left_is_zero = left <= 0
left_clamped = torch.clamp(left, min=t_min_t)
log_left = torch.log(left_clamped)
right_clamped = torch.clamp(right, min=t_min_t)
log_right = torch.log(right_clamped)
sigma_c = sigma.to(device=device, dtype=dtype)
z_left = (log_left.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
z_right = (log_right.unsqueeze(-1) - centers_t.unsqueeze(0)) / sigma_c
cdf_left = _normal_cdf_stable(z_left)
if left_is_zero.any():
cdf_left = torch.where(left_is_zero.unsqueeze(-1),
torch.zeros_like(cdf_left), cdf_left)
cdf_right = _normal_cdf_stable(z_right)
delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0) # (n_bins, R)
if logits.ndim == 3:
alpha = F.softplus(logits) + float(alpha_floor) # (B,J,R)
else:
logits_used = logits[:, offset:]
alpha = (F.softplus(logits_used) + float(alpha_floor)
).view(logits.size(0), j, r) # (B,J,R)
h_jk = torch.einsum("bjr,kr->bjk", alpha, delta_basis) # (B,J,n_bins)
h_k = h_jk.sum(dim=1) # (B,n_bins)
h_k = torch.clamp(h_k, min=eps)
h_jk = torch.clamp(h_jk, min=eps)
p_comp = torch.exp(-h_k) # (B,n_bins)
one_minus = -torch.expm1(-h_k) # (B,n_bins) = 1-exp(-H)
ratio = h_jk / torch.clamp(h_k.unsqueeze(1), min=eps)
p_event = one_minus.unsqueeze(1) * ratio # (B,J,n_bins)
ones = torch.ones((alpha.size(0), 1), device=device, dtype=dtype)
cum = torch.cumprod(p_comp, dim=1) # survival after each bin
s_prev = torch.cat([ones, cum[:, :-1]], dim=1) # survival before each bin
cif_bins = torch.cumsum(s_prev.unsqueeze(
1) * p_event, dim=2) # (B,J,n_bins)
finite_edges_arr = np.asarray(finite_edges, dtype=float)
tau_to_idx: List[int] = []
for tau in taus:
tau_f = float(tau)
if not math.isfinite(tau_f):
raise ValueError("taus must be finite for discrete-time CIF")
diffs = np.abs(finite_edges_arr - tau_f)
idx0 = int(np.argmin(diffs))
if diffs[idx0] > 1e-6:
raise ValueError(
f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[idx0]})"
)
tau_to_idx.append(idx0)
idx = torch.tensor(tau_to_idx, device=device, dtype=torch.long)
cif = cif_bins.index_select(dim=2, index=idx) # (B,J,H)
if not return_survival:
return cif
survival = cum.index_select(dim=1, index=idx) # (B,H)
return cif, survival
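# Hedged sketch of the Route-3 construction above: each cause j mixes R lognormal basis
# bumps with non-negative weights alpha_{j,r} = softplus(logit_{j,r}) + alpha_floor.
# The hazard mass assigned to finite bin k (edges t_{k-1}, t_k) is
#   H_{j,k} = sum_r alpha_{j,r} * [Phi((log t_k - mu_r)/sigma) - Phi((log t_{k-1} - mu_r)/sigma)],
# with Phi the standard normal CDF, mu_r the centers in log-time, and sigma the scalar
# bandwidth. Bins are then treated exactly like the discrete-time case:
#   p_comp[k] = exp(-sum_j H_{j,k}),  p_event[j,k] = (1 - p_comp[k]) * H_{j,k} / sum_j H_{j,k}.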
# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
cause_cif: np.ndarray,
horizons: Sequence[float],
*,
tol: float = 1e-6,
name: str = "",
strict: bool = False,
survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
"""Run basic sanity checks on CIF arrays.
Args:
cause_cif: (N, K, H)
horizons: length H
tol: tolerance for inequalities
name: model name for messages
strict: if True, raise ValueError on first failure
survival: optional (N, H) survival values at the same horizons
Returns:
(integrity_ok, notes)
"""
notes: List[str] = []
model_tag = f"[{name}] " if name else ""
def _fail(msg: str) -> None:
full = model_tag + msg
if strict:
raise ValueError(full)
print("WARNING:", full)
notes.append(msg)
cif = np.asarray(cause_cif)
if cif.ndim != 3:
_fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
return False, notes
N, K, H = cif.shape
if H != len(horizons):
_fail(
f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")
# (5) Finite
if not np.isfinite(cif).all():
_fail("integrity: non-finite values (NaN/Inf) in cause_cif")
# (1) Range
cmin = float(np.nanmin(cif))
cmax = float(np.nanmax(cif))
if cmin < -tol:
_fail(f"integrity: range min={cmin} < -tol={-tol}")
if cmax > 1.0 + tol:
_fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")
# (2) Monotonicity in horizons (per n,k)
diffs = np.diff(cif, axis=2)
if diffs.size > 0:
if np.nanmin(diffs) < -tol:
_fail("integrity: monotonicity violated (found negative diff along horizons)")
# (3) Probability mass: sum_k CIF <= 1
mass = np.sum(cif, axis=1) # (N,H)
mass_max = float(np.nanmax(mass))
if mass_max > 1.0 + tol:
_fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")
# (4) Conservation with survival, if provided
if survival is None:
warn = "integrity: survival not provided; skipping conservation check"
if strict:
# still skip (requested behavior), but keep message for context
notes.append(warn)
else:
print("WARNING:", model_tag + warn)
notes.append(warn)
else:
s = np.asarray(survival, dtype=float)
if s.shape != (N, H):
_fail(
f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
else:
recon = 1.0 - s
err = np.abs(recon - mass)
# Discrete-time should be very tight; exponential may accumulate slightly more numerical error.
tol_cons = max(float(tol), 1e-4)
if float(np.nanmax(err)) > tol_cons:
_fail(
f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})")
ok = len([n for n in notes if not n.endswith(
"skipping conservation check")]) == 0
return ok, notes
# ============================================================
# Metrics
# ============================================================
# --- Rank-based ROC AUC (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
"""Vectorized midrank computation (ties -> average ranks)."""
x = np.asarray(x, dtype=float)
n = int(x.shape[0])
if n == 0:
return np.asarray([], dtype=float)
order = np.argsort(x, kind="mergesort")
z = x[order]
# Find tie groups in sorted order.
diff = np.diff(z)
# boundaries includes 0 and n
boundaries = np.concatenate(
[np.array([0], dtype=int), np.nonzero(diff != 0)
[0] + 1, np.array([n], dtype=int)]
)
starts = boundaries[:-1]
ends = boundaries[1:]
lens = ends - starts
# Midrank for each group in 1-based rank space.
mids = 0.5 * (starts + ends - 1) + 1.0
t_sorted = np.repeat(mids, lens).astype(float, copy=False)
out = np.empty(n, dtype=float)
out[order] = t_sorted
return out
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
"""Rank-based ROC AUC via MannWhitney U statistic (ties handled by midranks).
Returns NaN for degenerate labels.
"""
y = np.asarray(y_true, dtype=int)
s = np.asarray(y_score, dtype=float)
if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]:
raise ValueError("roc_auc_rank expects 1D arrays of equal length")
m = int(np.sum(y == 1))
n = int(np.sum(y == 0))
if m == 0 or n == 0:
return float("nan")
ranks = compute_midrank(s)
sum_pos = float(np.sum(ranks[y == 1]))
auc = (sum_pos - m * (m + 1) / 2.0) / (m * n)
return float(auc)
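# Tiny worked example (hedged, not executed): y = [0, 0, 1, 1], s = [0.1, 0.4, 0.35, 0.8].
# Midranks of s are [1, 3, 2, 4]; the positive ranks sum to 2 + 4 = 6 with m = n = 2, so
# AUC = (6 - 2*3/2) / (2*2) = 0.75, i.e. 3 of the 4 positive/negative pairs are ordered
# correctly by the score.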
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
return float(np.mean((p - y) ** 2))
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
bins, ici = _calibration_bins_and_ici(
p, y, n_bins=int(n_bins), return_bins=True)
return {"bins": bins, "ici": float(ici)}
def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
"""Fast ICI only (no per-bin point export)."""
_, ici = _calibration_bins_and_ici(
p, y, n_bins=int(n_bins), return_bins=False)
return float(ici)
def _calibration_bins_and_ici(
p: np.ndarray,
y: np.ndarray,
*,
n_bins: int,
return_bins: bool,
) -> Tuple[List[Dict[str, Any]], float]:
"""Vectorized quantile binning for calibration + ICI."""
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
if p.size == 0:
        return [], float("nan")
q = np.linspace(0.0, 1.0, int(n_bins) + 1)
edges = np.quantile(p, q)
edges = np.asarray(edges, dtype=float)
edges[0] = -np.inf
edges[-1] = np.inf
# Bin assignment: i if edges[i] < p <= edges[i+1]
bin_idx = np.searchsorted(edges, p, side="right") - 1
bin_idx = np.clip(bin_idx, 0, int(n_bins) - 1)
counts = np.bincount(bin_idx, minlength=int(n_bins)).astype(float)
sum_p = np.bincount(bin_idx, weights=p,
minlength=int(n_bins)).astype(float)
sum_y = np.bincount(bin_idx, weights=y,
minlength=int(n_bins)).astype(float)
nonempty = counts > 0
if not np.any(nonempty):
        return [], float("nan")
p_mean = np.zeros(int(n_bins), dtype=float)
y_mean = np.zeros(int(n_bins), dtype=float)
p_mean[nonempty] = sum_p[nonempty] / counts[nonempty]
y_mean[nonempty] = sum_y[nonempty] / counts[nonempty]
diffs = np.abs(p_mean[nonempty] - y_mean[nonempty])
ici = float(np.mean(diffs)) if diffs.size else float("nan")
if not return_bins:
return [], ici
bins: List[Dict[str, Any]] = []
idxs = np.nonzero(nonempty)[0]
for i in idxs.tolist():
bins.append(
{
"bin": int(i),
"p_mean": float(p_mean[i]),
"y_mean": float(y_mean[i]),
"n": int(counts[i]),
}
)
return bins, ici
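# Note on the metric above: ICI here is the mean absolute gap between predicted and
# observed event rates across non-empty quantile bins,
#   ICI = mean_b |p_mean_b - y_mean_b|.
# Hedged toy call: _calibration_bins_and_ici(p, y, n_bins=10, return_bins=True)
# returns ([{"bin": ..., "p_mean": ..., "y_mean": ..., "n": ...}, ...], ici).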
def _progress_line(done: int, total: int, prefix: str = "") -> str:
total_i = max(1, int(total))
done_i = max(0, min(int(done), total_i))
width = 28
frac = done_i / total_i
filled = int(round(width * frac))
bar = "#" * filled + "-" * (width - filled)
pct = 100.0 * frac
return f"{prefix}[{bar}] {done_i}/{total_i} ({pct:5.1f}%)"
def _should_show_progress(mode: str) -> bool:
m = str(mode).strip().lower()
if m in {"0", "false", "no", "none", "off"}:
return False
# Default: show if interactive.
if m in {"auto", "1", "true", "yes", "on", "bar"}:
try:
return bool(sys.stdout.isatty())
except Exception:
return True
return True
def _safe_float(x: Any, default: float = float("nan")) -> float:
try:
return float(x)
except Exception:
return float(default)
def _ensure_dir(path: str) -> None:
os.makedirs(path, exist_ok=True)
def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
"""Load 0-based cause_id -> name mapping.
labels.csv is assumed to be one label per line, in disease-id order.
"""
if not os.path.exists(path):
return {}
mapping: Dict[int, str] = {}
with open(path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
name = line.strip()
if name:
mapping[int(i)] = name
return mapping
def pick_focus_causes(
*,
counts_within_tau: Optional[np.ndarray],
n_disease: int,
death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
k: int = 5,
) -> List[int]:
"""Pick focus causes for user-facing evaluation.
Rule:
1) Always include death_cause_id first.
2) Then add K additional causes by descending event count if available.
If counts_within_tau is None, fall back to descending cause_id coverage proxy.
Notes:
- counts_within_tau is expected to be shape (n_disease,).
- Deterministic: ties broken by smaller cause id.
"""
n_disease_i = int(n_disease)
if death_cause_id < 0 or death_cause_id >= n_disease_i:
print(
f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); "
"it will be omitted from focus causes."
)
focus: List[int] = []
else:
focus = [int(death_cause_id)]
candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)]
if counts_within_tau is not None:
c = np.asarray(counts_within_tau).astype(float)
if c.shape[0] != n_disease_i:
print(
"WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering."
)
counts_within_tau = None
else:
# Sort by (-count, cause_id)
order = sorted(candidates, key=lambda i: (-float(c[i]), int(i)))
order = [i for i in order if float(c[i]) > 0]
focus.extend([int(i) for i in order[: int(k)]])
if counts_within_tau is None:
# Fallback: deterministic coverage proxy (descending id, excluding death), then take K.
# (Real coverage requires data; this path is mostly for robustness.)
order = sorted(candidates, key=lambda i: (-int(i)))
focus.extend([int(i) for i in order[: int(k)]])
# De-dup while preserving order
seen = set()
out: List[int] = []
for cid in focus:
if cid not in seen:
out.append(cid)
seen.add(cid)
return out
def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
_ensure_dir(os.path.dirname(os.path.abspath(path)) or ".")
with open(path, "w", newline="", encoding="utf-8") as f:
w = csv.DictWriter(f, fieldnames=fieldnames)
w.writeheader()
for r in rows:
w.writerow(r)
def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
"""Return list of (sex_label, mask) slices including an 'all' slice.
If sex is missing, returns only ('all', None).
"""
out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
if sex is None:
return out
s = np.asarray(sex)
if s.ndim != 1:
return out
for val in [0, 1]:
m = (s == val)
if int(np.sum(m)) > 0:
out.append((str(val), m))
return out
def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1))
edges = np.asarray(edges, dtype=float)
edges[0] = -np.inf
edges[-1] = np.inf
return edges
def compute_risk_stratification_bins(
p: np.ndarray,
y: np.ndarray,
*,
q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
"""Compute quantile-based risk strata and a compact summary."""
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
n = int(p.shape[0])
if n == 0:
return 0, [], {
"y_overall": float("nan"),
"top_decile_y_rate": float("nan"),
"bottom_half_y_rate": float("nan"),
"lift_top10_vs_bottom50": float("nan"),
"slope_pred_vs_obs": float("nan"),
}
# Choose quantiles robustly.
q = int(q_default)
if n < 200:
q = 5
edges = _quantile_edges(p, q)
y_overall = float(np.mean(y))
bin_rows: List[Dict[str, Any]] = []
p_means: List[float] = []
y_rates: List[float] = []
n_bins: List[int] = []
for i in range(q):
mask = (p > edges[i]) & (p <= edges[i + 1])
nb = int(np.sum(mask))
if nb == 0:
# Keep the row for consistent plotting; set NaNs.
bin_rows.append(
{
"q": int(i + 1),
"n_bin": 0,
"p_mean": float("nan"),
"y_rate": float("nan"),
"y_overall": y_overall,
"lift_vs_overall": float("nan"),
}
)
continue
p_mean = float(np.mean(p[mask]))
y_rate = float(np.mean(y[mask]))
lift = float(y_rate / y_overall) if y_overall > 0 else float("nan")
bin_rows.append(
{
"q": int(i + 1),
"n_bin": nb,
"p_mean": p_mean,
"y_rate": y_rate,
"y_overall": y_overall,
"lift_vs_overall": lift,
}
)
p_means.append(p_mean)
y_rates.append(y_rate)
n_bins.append(nb)
# Summary
top_mask = (p > edges[q - 1]) & (p <= edges[q])
bot_half_mask = (p > edges[0]) & (p <= edges[q // 2])
top_y = float(np.mean(y[top_mask])) if int(
np.sum(top_mask)) > 0 else float("nan")
bot_y = float(np.mean(y[bot_half_mask])) if int(
np.sum(bot_half_mask)) > 0 else float("nan")
lift_top_vs_bottom = float(top_y / bot_y) if (np.isfinite(top_y)
and np.isfinite(bot_y) and bot_y > 0) else float("nan")
slope = float("nan")
if len(p_means) >= 2:
# Weighted least squares slope of y_rate ~ p_mean.
x = np.asarray(p_means, dtype=float)
yy = np.asarray(y_rates, dtype=float)
w = np.asarray(n_bins, dtype=float)
xm = float(np.average(x, weights=w))
ym = float(np.average(yy, weights=w))
denom = float(np.sum(w * (x - xm) ** 2))
if denom > 0:
slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom)
summary = {
"y_overall": y_overall,
"top_decile_y_rate": top_y,
"bottom_half_y_rate": bot_y,
"lift_top10_vs_bottom50": lift_top_vs_bottom,
"slope_pred_vs_obs": slope,
}
return q, bin_rows, summary
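# Interpretation sketch for the summary above (hedged): with q=10 strata,
# lift_top10_vs_bottom50 compares the observed event rate in the top risk decile
# against the bottom five deciles pooled, and slope_pred_vs_obs is the count-weighted
# least-squares slope of observed rate on mean predicted risk (close to 1 when
# predicted and observed rates track each other).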
def compute_capture_points(
p: np.ndarray,
y: np.ndarray,
k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
n = int(p.shape[0])
if n == 0:
return []
order = np.argsort(-p)
y_sorted = y[order]
events_total = float(np.sum(y_sorted))
rows: List[Dict[str, Any]] = []
for k in k_pcts:
kf = float(k)
n_targeted = int(math.ceil(n * kf / 100.0))
n_targeted = max(1, min(n_targeted, n))
events_targeted = float(np.sum(y_sorted[:n_targeted]))
capture = float(events_targeted /
events_total) if events_total > 0 else float("nan")
precision = float(events_targeted / float(n_targeted))
rows.append(
{
"k_pct": int(k),
"n_targeted": int(n_targeted),
"events_targeted": float(events_targeted),
"events_total": float(events_total),
"event_capture_rate": capture,
"precision_in_targeted": precision,
}
)
return rows
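# Worked sketch (hedged): with n = 100 individuals and k_pct = 10, the 10 highest-risk
# individuals are "targeted"; event_capture_rate is the share of all events that fall
# inside that targeted group, and precision_in_targeted is the event rate within it.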
def compute_event_rate_at_topk_causes(
p_tau: np.ndarray,
y_tau: np.ndarray,
topk_list: Sequence[int],
) -> List[Dict[str, Any]]:
"""Compute Event Rate@K for cross-cause prioritization.
For each individual, rank causes by predicted risk p_tau at a fixed horizon.
For each K, select top-K causes and compute the fraction that occur within the horizon.
Args:
p_tau: (N, K) predicted CIFs at a fixed horizon
y_tau: (N, K) binary labels (0/1) whether cause occurs within the horizon
topk_list: list of K values to evaluate
Returns:
List of rows with:
- topk
- event_rate_mean / event_rate_median
- recall_mean / recall_median (averaged over individuals with >=1 true cause)
- n_total / n_valid_recall
"""
p = np.asarray(p_tau, dtype=float)
y = np.asarray(y_tau, dtype=float)
if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape:
raise ValueError(
"compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape")
n, k_total = p.shape
if n == 0 or k_total == 0:
out: List[Dict[str, Any]] = []
for kk in topk_list:
out.append(
{
"topk": int(max(1, int(kk))),
"event_rate_mean": float("nan"),
"event_rate_median": float("nan"),
"recall_mean": float("nan"),
"recall_median": float("nan"),
"n_total": int(n),
"n_valid_recall": 0,
}
)
return out
# Sanitize K list.
topks = sorted({int(x) for x in topk_list if int(x) > 0})
if not topks:
return []
max_k = min(int(max(topks)), int(k_total))
if max_k <= 0:
return []
# Efficient: get top max_k causes per individual, then sort within those.
part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k] # (N, max_k)
p_part = np.take_along_axis(p, part, axis=1)
order = np.argsort(-p_part, axis=1)
top_sorted = np.take_along_axis(part, order, axis=1) # (N, max_k)
out_rows: List[Dict[str, Any]] = []
for kk in topks:
kk_eff = min(int(kk), int(k_total))
idx = top_sorted[:, :kk_eff]
y_sel = np.take_along_axis(y, idx, axis=1)
# Selected true causes per person
hit = np.sum(y_sel, axis=1)
# Precision-like: fraction of selected causes that occur
per_person = hit / \
float(kk_eff) if kk_eff > 0 else np.full((n,), np.nan)
# Recall@K: fraction of true causes covered by top-K (undefined when no true cause)
g = np.sum(y, axis=1)
valid = g > 0
recall = np.full((n,), np.nan, dtype=float)
recall[valid] = hit[valid] / g[valid]
out_rows.append(
{
"topk": int(kk_eff),
"event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"),
"event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"),
"recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"),
"recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"),
"n_total": int(n),
"n_valid_recall": int(np.sum(valid)),
}
)
return out_rows
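# Hedged toy reading of the function above: with p_tau, y_tau of shape (N, K_total)
# and topk_list=[5, 10], each person's causes are ranked by predicted risk;
# EventRate@K = (# of the K selected causes that actually occur) / K, and
# Recall@K = (# selected true causes) / (# true causes), averaged only over people
# with at least one true cause (n_valid_recall).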
def compute_random_ranking_baseline_topk(
y_tau: np.ndarray,
topk_list: Sequence[int],
*,
z: float = 1.645,
) -> List[Dict[str, Any]]:
"""Random ranking baseline for Event Rate@K and Recall@K.
Baseline definition:
- For each individual, pick K causes uniformly at random without replacement.
- EventRate@K = (# selected causes that occur) / K.
- Recall@K = (# selected causes that occur) / (# causes that occur), averaged over individuals with >=1 true cause.
This function computes the expected baseline mean and an approximate 5-95% range
for the population mean using a normal approximation of the hypergeometric variance.
Args:
y_tau: (N, K_total) binary labels
topk_list: K values
z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%)
Returns:
Rows with baseline means and p05/p95 for both metrics.
"""
y = np.asarray(y_tau, dtype=float)
if y.ndim != 2:
raise ValueError(
"compute_random_ranking_baseline_topk expects y_tau with shape (N,K)")
n, k_total = y.shape
topks = sorted({int(x) for x in topk_list if int(x) > 0})
if not topks:
return []
g = np.sum(y, axis=1) # (N,)
valid = g > 0
n_valid = int(np.sum(valid))
out: List[Dict[str, Any]] = []
for kk in topks:
kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk)
if n == 0 or k_total == 0 or kk_eff <= 0:
out.append(
{
"topk": int(max(1, kk_eff)),
"baseline_event_rate_mean": float("nan"),
"baseline_event_rate_p05": float("nan"),
"baseline_event_rate_p95": float("nan"),
"baseline_recall_mean": float("nan"),
"baseline_recall_p05": float("nan"),
"baseline_recall_p95": float("nan"),
"n_total": int(n),
"n_valid_recall": int(n_valid),
"k_total": int(k_total),
"baseline_method": "random_ranking_hypergeometric_normal_approx",
}
)
continue
# Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total.
er_mean = float(np.mean(g / float(k_total)))
# Variance of hypergeometric count X:
# Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p=g/K_total.
if k_total > 1 and kk_eff < k_total:
p = g / float(k_total)
finite_corr = (float(k_total - kk_eff) / float(k_total - 1))
var_x = float(kk_eff) * p * (1.0 - p) * finite_corr
else:
var_x = np.zeros_like(g, dtype=float)
var_er = var_x / (float(kk_eff) ** 2)
se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n))
er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0))
er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0))
# Expected Recall@K for individuals with g>0 is K/K_total (clipped).
rec_mean = float(min(float(kk_eff) / float(k_total), 1.0))
if n_valid > 0:
var_rec = np.zeros_like(g, dtype=float)
gv = g[valid]
var_xv = var_x[valid]
# Var( X / g ) = Var(X) / g^2 (approx; g is fixed per individual)
var_rec_v = var_xv / (gv ** 2)
se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid)
rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0))
rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0))
else:
rec_p05 = float("nan")
rec_p95 = float("nan")
out.append(
{
"topk": int(kk_eff),
"baseline_event_rate_mean": er_mean,
"baseline_event_rate_p05": er_p05,
"baseline_event_rate_p95": er_p95,
"baseline_recall_mean": rec_mean,
"baseline_recall_p05": float(rec_p05),
"baseline_recall_p95": float(rec_p95),
"n_total": int(n),
"n_valid_recall": int(n_valid),
"k_total": int(k_total),
"baseline_method": "random_ranking_hypergeometric_normal_approx",
}
)
return out
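# Sketch of the baseline above: selecting K of K_total causes uniformly at random gives
# a hypergeometric count X of true causes with E[X] = K * g / K_total and
# Var(X) = K * p * (1 - p) * (K_total - K) / (K_total - 1), p = g / K_total, so
# E[EventRate@K] = g / K_total and E[Recall@K] = K / K_total (for g > 0, clipped at 1).
# The p05/p95 bounds use a normal approximation for the mean over individuals with
# the default z = 1.645.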
def count_occurs_within_horizon(
loader: DataLoader,
offset_years: float,
tau_years: float,
n_disease: int,
device: str,
) -> Tuple[np.ndarray, int]:
"""Count per-person occurrence within tau after the prediction context.
Returns counts[k] = number of individuals with disease k at least once in (t_ctx, t_ctx+tau].
"""
counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
n_total_eval = 0
for batch in loader:
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
event_seq = event_seq.to(device)
time_seq = time_seq.to(device)
keep, t_ctx, _ = select_context_indices(
event_seq, time_seq, offset_years)
if not keep.any():
continue
n_total_eval += int(keep.sum().item())
event_seq = event_seq[keep]
time_seq = time_seq[keep]
t_ctx = t_ctx[keep]
B, L = event_seq.shape
b = torch.arange(B, device=device)
t0 = time_seq[b, t_ctx]
t1 = t0 + (float(tau_years) * DAYS_PER_YEAR)
idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
in_window = (
(idxs > t_ctx.unsqueeze(1))
& (time_seq >= t0.unsqueeze(1))
& (time_seq <= t1.unsqueeze(1))
& (event_seq >= 2)
& (event_seq != 0)
)
if not in_window.any():
continue
b_idx, t_idx = in_window.nonzero(as_tuple=True)
disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
# unique per (person, disease) to count per-person within-window occurrence
key = b_idx.to(torch.long) * int(n_disease) + disease_ids
uniq = torch.unique(key)
uniq_disease = uniq % int(n_disease)
counts.scatter_add_(0, uniq_disease, torch.ones_like(
uniq_disease, dtype=torch.long))
return counts.detach().cpu().numpy(), int(n_total_eval)
# ============================================================
# Evaluation core
# ============================================================
def instantiate_model_and_head(
cfg: Dict[str, Any],
dataset: HealthDataset,
device: str,
checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
model_type = str(cfg["model_type"])
loss_type = str(cfg["loss_type"])
loss_params: Dict[str, Any] = {}
if loss_type == "exponential":
out_dims = [dataset.n_disease]
elif loss_type == "discrete_time_cif":
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif loss_type == "lognormal_basis_binned_hazard_cif":
centers = cfg.get("lognormal_centers", None)
if centers is None:
centers = cfg.get("centers", None)
if not isinstance(centers, list) or len(centers) == 0:
raise ValueError(
"lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
)
r = len(centers)
desired_total = int(dataset.n_disease) * int(r)
legacy_total = 1 + desired_total
# Prefer the new shape (K,R) but keep compatibility with older checkpoints
# that used a single flattened dimension (1 + K*R).
out_dims = [int(dataset.n_disease), int(r)]
if checkpoint_path:
try:
ckpt = torch.load(checkpoint_path, map_location="cpu")
head_sd = ckpt.get("head_state_dict", {})
w = head_sd.get("net.2.weight", None)
if isinstance(w, torch.Tensor) and w.ndim == 2:
out_features = int(w.shape[0])
if out_features == legacy_total:
out_dims = [legacy_total]
elif out_features == desired_total:
out_dims = [int(dataset.n_disease), int(r)]
else:
raise ValueError(
f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)"
)
except Exception as e:
raise ValueError(
f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}"
)
loss_params["centers"] = centers
loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0))
else:
raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
if model_type == "delphi_fork":
backbone = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
).to(device)
elif model_type == "sap_delphi":
# Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path.
emb_path = cfg.get("pretrained_emb_path", None)
if emb_path in {"", None}:
emb_path = cfg.get("pretrained_emd_path", None)
if emb_path in {"", None}:
run_dir = os.path.dirname(os.path.abspath(
checkpoint_path)) if checkpoint_path else ""
print(
f"WARNING: SapDelphi pretrained embedding path missing in config "
f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
f"checkpoint={checkpoint_path} run_dir={run_dir}"
)
backbone = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=2,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=dataset.n_cont,
n_cate=dataset.n_cate,
cate_dims=dataset.cate_dims,
pretrained_weights_path=emb_path,
freeze_embeddings=True,
).to(device)
else:
raise ValueError(f"Unsupported model_type: {model_type}")
head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
return backbone, head, loss_type, bin_edges, loss_params
@torch.no_grad()
def predict_cifs_for_model(
backbone: torch.nn.Module,
head: torch.nn.Module,
loss_type: str,
bin_edges: Sequence[float],
loss_params: Dict[str, Any],
loader: DataLoader,
device: str,
offset_years: float,
eval_horizons: Sequence[float],
n_disease: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run model and produce cause-specific, time-dependent CIF outputs.
    Returns:
        cif_full: (N, K, H)
        survival: (N, H)
        y_cause_within_tau: (N, K, H)
        sex: (N,) sex values for the retained samples (used for subgroup slices)
    NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
"""
backbone.eval()
head.eval()
# We will accumulate in CPU lists, then concat.
cif_full_list: List[np.ndarray] = []
survival_list: List[np.ndarray] = []
y_cause_within_list: List[np.ndarray] = []
sex_list: List[np.ndarray] = []
for batch in loader:
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
event_seq = event_seq.to(device)
time_seq = time_seq.to(device)
cont_feats = cont_feats.to(device)
cate_feats = cate_feats.to(device)
sexes = sexes.to(device)
keep, t_ctx, _ = select_context_indices(
event_seq, time_seq, offset_years)
if not keep.any():
continue
# filter batch
event_seq = event_seq[keep]
time_seq = time_seq[keep]
cont_feats = cont_feats[keep]
cate_feats = cate_feats[keep]
sexes_k = sexes[keep]
t_ctx = t_ctx[keep]
h = backbone(event_seq, time_seq, sexes_k,
cont_feats, cate_feats) # (B,L,D)
b = torch.arange(h.size(0), device=device)
c = h[b, t_ctx] # (B,D)
logits = head(c)
if loss_type == "exponential":
cif_full, survival = cifs_from_exponential_logits(
logits, eval_horizons, return_survival=True) # (B,K,H), (B,H)
elif loss_type == "discrete_time_cif":
cif_full, survival = cifs_from_discrete_time_logits(
# (B,K,H), (B,H)
logits, bin_edges, eval_horizons, return_survival=True)
elif loss_type == "lognormal_basis_binned_hazard_cif":
centers = loss_params.get("centers", None)
sigma = loss_params.get("sigma", None)
if centers is None or sigma is None:
raise ValueError(
"lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']")
cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits(
logits,
centers=centers,
sigma=sigma,
bin_edges=bin_edges,
taus=eval_horizons,
eps=float(loss_params.get("loss_eps", 1e-8)),
alpha_floor=float(loss_params.get("alpha_floor", 0.0)),
return_survival=True,
)
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")
# Within-horizon labels for all causes: disease k occurs within tau after context.
y_within_full = torch.stack(
[
multi_hot_ever_within_horizon(
event_seq=event_seq,
time_seq=time_seq,
t_ctx=t_ctx,
tau_years=float(tau),
n_disease=int(n_disease),
).to(torch.float32)
for tau in eval_horizons
],
dim=2,
) # (B,K,H)
cif_full_list.append(cif_full.detach().cpu().numpy())
survival_list.append(survival.detach().cpu().numpy())
y_cause_within_list.append(y_within_full.detach().cpu().numpy())
sex_list.append(sexes_k.detach().cpu().numpy())
if not cif_full_list:
raise RuntimeError(
"No valid samples for evaluation (all batches filtered out by offset).")
cif_full = np.concatenate(cif_full_list, axis=0)
survival = np.concatenate(survival_list, axis=0)
y_cause_within = np.concatenate(y_cause_within_list, axis=0)
sex = np.concatenate(
sex_list, axis=0) if sex_list else np.array([], dtype=int)
return cif_full, survival, y_cause_within, sex
def evaluate_one_model(
model_name: str,
cif_full: np.ndarray,
y_cause_within_tau: np.ndarray,
eval_horizons: Sequence[float],
out_rows: List[Dict[str, Any]],
calib_rows: List[Dict[str, Any]],
calib_cause_ids: Optional[Sequence[int]],
n_calib_bins: int = 10,
metric_workers: int = 0,
progress: str = "auto",
) -> None:
"""Compute per-cause metrics for ALL diseases.
Notes:
- Writes scalar metrics for all causes into out_rows.
- Writes calibration-bin points only for calib_cause_ids (to keep outputs tractable).
"""
cif_full = np.asarray(cif_full, dtype=float)
y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float)
if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3:
raise ValueError(
"Expected cif_full and y_cause_within_tau with shape (N, K, H)")
if cif_full.shape != y_cause_within_tau.shape:
raise ValueError(
f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}"
)
N, K, H = cif_full.shape
if H != len(eval_horizons):
raise ValueError("H mismatch between cif_full and eval_horizons")
calib_set = set(int(x)
for x in calib_cause_ids) if calib_cause_ids is not None else set()
workers = int(metric_workers)
if workers <= 0:
workers = int(min(8, os.cpu_count() or 1))
workers = max(1, workers)
show_progress = _should_show_progress(progress)
def _eval_chunk(
*,
tau: float,
p_tau: np.ndarray,
y_tau: np.ndarray,
brier_by_cause: np.ndarray,
cause_ids: np.ndarray,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]:
local_rows: List[Dict[str, Any]] = []
local_calib: List[Dict[str, Any]] = []
for cid in cause_ids.tolist():
p = p_tau[:, cid]
y = y_tau[:, cid]
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_brier",
"horizon": float(tau),
"cause": int(cid),
"value": float(brier_by_cause[cid]),
"ci_low": "",
"ci_high": "",
}
)
# ICI: compute bins only if we will export them.
need_bins = (not calib_set) or (int(cid) in calib_set)
if need_bins:
cal = calibration_deciles(p, y, n_bins=n_calib_bins)
ici = float(cal["ici"])
else:
cal = None
ici = calibration_ici_only(p, y, n_bins=n_calib_bins)
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_ici",
"horizon": float(tau),
"cause": int(cid),
"value": float(ici),
"ci_low": "",
"ci_high": "",
}
)
# Secondary: discrimination via AUC at the same horizon (point estimate only).
auc = roc_auc_rank(y, p)
local_rows.append(
{
"model_name": model_name,
"metric_name": "cause_auc",
"horizon": float(tau),
"cause": int(cid),
"value": float(auc),
"ci_low": "",
"ci_high": "",
}
)
if need_bins and cal is not None:
for binfo in cal.get("bins", []):
local_calib.append(
{
"model_name": model_name,
"task": "cause_k",
"horizon": float(tau),
"cause_id": int(cid),
"bin_index": int(binfo["bin"]),
"p_mean": float(binfo["p_mean"]),
"y_mean": float(binfo["y_mean"]),
"n_in_bin": int(binfo["n"]),
}
)
return local_rows, local_calib, int(cause_ids.size)
# Cause-specific, time-dependent metrics per horizon.
for h_i, tau in enumerate(eval_horizons):
p_tau = cif_full[:, :, h_i] # (N, K)
y_tau = y_cause_within_tau[:, :, h_i] # (N, K)
# Vectorized Brier for speed.
brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0) # (K,)
# Parallelize disease-level metrics; chunk to avoid millions of futures.
all_ids = np.arange(int(K), dtype=int)
chunks = np.array_split(all_ids, workers)
done = 0
prefix = f"[{model_name}] tau={float(tau)}y "
t0 = time.time()
if workers <= 1:
for ch in chunks:
r_chunk, c_chunk, n_done = _eval_chunk(
tau=float(tau),
p_tau=p_tau,
y_tau=y_tau,
brier_by_cause=brier_by_cause,
cause_ids=ch,
)
out_rows.extend(r_chunk)
calib_rows.extend(c_chunk)
done += int(n_done)
if show_progress:
sys.stdout.write(
"\r" + _progress_line(done, int(K), prefix=prefix))
sys.stdout.flush()
else:
with ThreadPoolExecutor(max_workers=workers) as ex:
futs = [
ex.submit(
_eval_chunk,
tau=float(tau),
p_tau=p_tau,
y_tau=y_tau,
brier_by_cause=brier_by_cause,
cause_ids=ch,
)
for ch in chunks
if int(ch.size) > 0
]
for fut in as_completed(futs):
r_chunk, c_chunk, n_done = fut.result()
out_rows.extend(r_chunk)
calib_rows.extend(c_chunk)
done += int(n_done)
if show_progress:
sys.stdout.write(
"\r" + _progress_line(done, int(K), prefix=prefix))
sys.stdout.flush()
if show_progress:
dt = time.time() - t0
sys.stdout.write("\r" + _progress_line(int(K),
int(K), prefix=prefix) + f" ({dt:.1f}s)\n")
sys.stdout.flush()
def summarize_over_diseases(
rows: List[Dict[str, Any]],
*,
model_name: str,
eval_horizons: Sequence[float],
metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"),
) -> List[Dict[str, Any]]:
"""Summarize mean/median of each metric over diseases (per horizon)."""
out: List[Dict[str, Any]] = []
# Build metric_name -> horizon -> list of values
bucket: Dict[Tuple[str, float], List[float]] = {}
for r in rows:
if r.get("model_name") != model_name:
continue
m = str(r.get("metric_name"))
if m not in set(metrics):
continue
h = _safe_float(r.get("horizon"))
v = _safe_float(r.get("value"))
if not np.isfinite(h):
continue
if not np.isfinite(v):
continue
bucket.setdefault((m, float(h)), []).append(float(v))
for tau in eval_horizons:
ht = float(tau)
for m in metrics:
vals = bucket.get((str(m), ht), [])
if vals:
arr = np.asarray(vals, dtype=float)
mean_v = float(np.mean(arr))
med_v = float(np.median(arr))
n_valid = int(arr.size)
else:
mean_v = float("nan")
med_v = float("nan")
n_valid = 0
out.append(
{
"model_name": str(model_name),
"metric_name": str(m),
"horizon": ht,
"mean": mean_v,
"median": med_v,
"n_valid": n_valid,
}
)
return out
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
fieldnames = [
"model_name",
"task",
"horizon",
"cause_id",
"bin_index",
"p_mean",
"y_mean",
"n_in_bin",
]
with open(path, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=fieldnames)
w.writeheader()
for r in rows:
w.writerow(r)
def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
fieldnames = [
"model_name",
"metric_name",
"horizon",
"cause",
"value",
"ci_low",
"ci_high",
]
with open(path, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=fieldnames)
w.writeheader()
for r in rows:
w.writerow(r)
def _make_eval_tag(split: str, offset_years: float) -> str:
"""Short tag for filenames written into run directories."""
off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
return f"{split}_offset{off}y"
def main() -> int:
ap = argparse.ArgumentParser(
description="Unified downstream evaluation via CIFs")
ap.add_argument("--models_json", type=str, required=True,
help="Path to models list JSON")
ap.add_argument("--data_prefix", type=str,
default="ukb", help="Dataset prefix")
ap.add_argument("--split", type=str, default="test",
choices=["train", "val", "test", "all"], help="Which split to evaluate")
ap.add_argument("--offset_years", type=float, default=0.5,
help="Anti-leakage offset (years)")
ap.add_argument("--eval_horizons", type=float,
nargs="*", default=DEFAULT_EVAL_HORIZONS)
ap.add_argument("--batch_size", type=int, default=128)
ap.add_argument("--num_workers", type=int, default=0)
ap.add_argument("--seed", type=int, default=123)
ap.add_argument("--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
ap.add_argument("--out_csv", type=str, default="eval_summary.csv")
ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
# Integrity checks
ap.add_argument("--integrity_strict", action="store_true", default=False)
ap.add_argument("--integrity_tol", type=float, default=1e-6)
# Speed/UX
ap.add_argument(
"--metric_workers",
type=int,
default=0,
help="Threads for per-disease metrics (0=auto, 1=disable parallelism)",
)
ap.add_argument(
"--progress",
type=str,
default="auto",
choices=["auto", "bar", "none"],
help="Progress visualization during per-disease evaluation",
)
# Export settings for user-facing experiments
ap.add_argument("--export_dir", type=str, default="eval_exports")
ap.add_argument("--death_cause_id", type=int,
default=DEFAULT_DEATH_CAUSE_ID)
ap.add_argument("--focus_k", type=int, default=5,
help="Additional non-death causes to include")
ap.add_argument("--capture_k_pcts", type=int,
nargs="*", default=[1, 5, 10, 20])
ap.add_argument(
"--capture_curve_max_pct",
type=int,
default=50,
help="If >0, also export a dense capture curve for k=1..max_pct",
)
# High-risk cause concentration (cross-cause prioritization)
ap.add_argument(
"--cause_concentration_topk",
type=int,
nargs="*",
default=[5, 10, 20, 50],
help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
)
ap.add_argument(
"--cause_concentration_write_random_baseline",
action="store_true",
default=False,
help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)",
)
args = ap.parse_args()
set_deterministic(args.seed)
specs = load_models_json(args.models_json)
if not specs:
raise ValueError("No models provided")
export_dir = str(args.export_dir)
_ensure_dir(export_dir)
cause_names = load_cause_names("labels.csv")
# Determine top-K causes from the evaluation split only (model-agnostic),
# aligned to time-dependent risk: occurrence within tau_max after context.
first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [
"bmi", "smoking", "alcohol"]
dataset_for_top = HealthDataset(
data_prefix=args.data_prefix, covariate_list=cov_list)
subset_for_top = build_eval_subset(
dataset_for_top,
train_ratio=float(first_cfg.get("train_ratio", 0.7)),
val_ratio=float(first_cfg.get("val_ratio", 0.15)),
seed=int(first_cfg.get("random_seed", 42)),
split=args.split,
)
loader_top = DataLoader(
subset_for_top,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=health_collate_fn,
)
tau_max = float(max(args.eval_horizons))
counts, n_total_eval = count_occurs_within_horizon(
loader=loader_top,
offset_years=args.offset_years,
tau_years=tau_max,
n_disease=dataset_for_top.n_disease,
device=args.device,
)
focus_causes = pick_focus_causes(
counts_within_tau=counts,
n_disease=int(dataset_for_top.n_disease),
death_cause_id=int(args.death_cause_id),
k=int(args.focus_k),
)
top_cause_ids = np.asarray(focus_causes, dtype=int)
# Export the chosen focus causes.
focus_rows: List[Dict[str, Any]] = []
for r, cid in enumerate(focus_causes, start=1):
row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)}
if cid in cause_names:
row["cause_name"] = cause_names[cid]
focus_rows.append(row)
focus_fieldnames = ["cause", "rank"] + \
(["cause_name"] if any("cause_name" in r for r in focus_rows) else [])
write_simple_csv(os.path.join(export_dir, "focus_causes.csv"),
focus_fieldnames, focus_rows)
# Metadata for focus causes (within tau_max).
top_causes_meta: List[Dict[str, Any]] = []
for cid in focus_causes:
n_case = int(counts[int(cid)]) if int(
cid) < int(counts.shape[0]) else 0
top_causes_meta.append(
{
"cause_id": int(cid),
"tau_years": float(tau_max),
"n_case_within_tau": n_case,
"n_control_within_tau": int(n_total_eval - n_case),
"n_total_eval": int(n_total_eval),
}
)
summary_rows: List[Dict[str, Any]] = []
calib_rows: List[Dict[str, Any]] = []
# Experiment exports (accumulated across models)
cap_points_rows: List[Dict[str, Any]] = []
cap_curve_rows: List[Dict[str, Any]] = []
conc_rows: List[Dict[str, Any]] = []
conc_base_rows: List[Dict[str, Any]] = []
# Track per-model integrity status for meta JSON.
integrity_meta: Dict[str, Any] = {}
# Evaluate each model
for spec in specs:
run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))
tag = _make_eval_tag(args.split, float(args.offset_years))
        # Remember the calibration-row offset so this model's slice of
        # calib_rows can be written into its run_dir.
calib_start = len(calib_rows)
cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
# Identifiers for consistent exports
model_id = str(spec.name)
        model_type = str(cfg.get("model_type", spec.model_type))
        loss_type_id = str(cfg.get("loss_type", spec.loss_type))
age_encoder = str(cfg.get("age_encoder", ""))
cov_type = "full" if _parse_bool(
cfg.get("full_cov", False)) else "partial"
cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [
"bmi", "smoking", "alcohol"]
dataset = HealthDataset(
data_prefix=args.data_prefix, covariate_list=cov_list)
subset = build_eval_subset(
dataset,
train_ratio=float(cfg.get("train_ratio", 0.7)),
val_ratio=float(cfg.get("val_ratio", 0.15)),
seed=int(cfg.get("random_seed", 42)),
split=args.split,
)
loader = DataLoader(
subset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=health_collate_fn,
)
ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
head.load_state_dict(ckpt["head_state_dict"], strict=True)
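        # For the lognormal-basis binned-hazard CIF loss, the bandwidth sigma is
        # stored with the loss criterion rather than the model: recover it from
        # the checkpoint's criterion state when present, otherwise fall back to
        # the configured bandwidth_init, then clamp to
        # [bandwidth_min, bandwidth_max].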
if loss_type == "lognormal_basis_binned_hazard_cif":
crit_state = ckpt.get("criterion_state_dict", {})
log_sigma = crit_state.get("log_sigma", None)
if isinstance(log_sigma, torch.Tensor):
log_sigma_t = log_sigma.to(device=args.device)
sigma = torch.exp(log_sigma_t)
else:
            init_bw = float(loss_params.get("bandwidth_init", 0.7))
            sigma = torch.tensor(init_bw, device=args.device)
bmin = float(loss_params.get("bandwidth_min", 1e-3))
bmax = float(loss_params.get("bandwidth_max", 10.0))
sigma = torch.clamp(sigma, min=bmin, max=bmax)
loss_params["sigma"] = sigma
(
cif_full,
survival,
y_cause_within_tau,
sex,
) = predict_cifs_for_model(
backbone,
head,
loss_type,
bin_edges,
loss_params,
loader,
args.device,
args.offset_years,
args.eval_horizons,
n_disease=int(dataset.n_disease),
)
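        # cif_full and y_cause_within_tau are indexed as
        # [individual, cause, horizon]; survival is only consumed by the
        # integrity check, and sex provides a per-individual label for the
        # optional sex-stratified slices below.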
# CIF integrity checks before metrics.
integrity_ok, integrity_notes = check_cif_integrity(
cif_full,
args.eval_horizons,
tol=float(args.integrity_tol),
name=spec.name,
strict=bool(args.integrity_strict),
survival=survival,
)
integrity_meta[spec.name] = {
"integrity_ok": bool(integrity_ok),
"integrity_notes": integrity_notes,
}
# Per-disease metrics for ALL diseases (written into the model's run_dir).
model_rows: List[Dict[str, Any]] = []
evaluate_one_model(
model_name=spec.name,
cif_full=cif_full,
y_cause_within_tau=y_cause_within_tau,
eval_horizons=args.eval_horizons,
out_rows=model_rows,
calib_rows=calib_rows,
calib_cause_ids=top_cause_ids.tolist(),
metric_workers=int(args.metric_workers),
progress=str(args.progress),
)
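        # Per-disease metric rows cover every cause, while calibration bins are
        # collected only for the focus causes (calib_cause_ids).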
# Summary over diseases (mean/median per horizon).
model_summary_rows = summarize_over_diseases(
model_rows,
model_name=spec.name,
eval_horizons=args.eval_horizons,
)
summary_rows.extend(model_summary_rows)
# ============================================================
# Experiment: High-Risk Cause Concentration at fixed horizon
# (cross-cause prioritization accuracy)
# ============================================================
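        # Rank ALL causes within each individual by predicted CIF at each
        # horizon, then summarize how observed events concentrate in the top-K
        # ranked causes (Event Rate@K) and how much of an individual's event
        # burden the top K recovers (Recall@K); see
        # compute_event_rate_at_topk_causes for the exact definitions.
        # Computed overall and, when available, per sex slice.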
topk_causes = [int(x) for x in args.cause_concentration_topk]
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for h_i, tau in enumerate(args.eval_horizons):
p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float)
y_tau_all = np.asarray(
y_cause_within_tau[:, :, h_i], dtype=float)
if sex_mask is not None:
p_tau_all = p_tau_all[sex_mask]
y_tau_all = y_tau_all[sex_mask]
for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes):
conc_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"horizon": float(tau),
"sex": sex_label,
**rr,
}
)
if bool(args.cause_concentration_write_random_baseline):
for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes):
conc_base_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"horizon": float(tau),
"sex": sex_label,
**rr,
}
)
# Convenience slices for user-facing experiments (focus causes only).
cause_cif_focus = cif_full[:, top_cause_ids, :]
y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
# ============================================================
# Experiment 2: High-risk capture points (+ optional curve)
# ============================================================
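        # For each focus cause and horizon, target the top k% of individuals by
        # predicted CIF and report the share of all events captured by that
        # slice (event_capture_rate) and the event rate inside it
        # (precision_in_targeted); see compute_capture_points. The optional
        # dense curve repeats this for k = 1..capture_curve_max_pct.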
k_pcts = [int(x) for x in args.capture_k_pcts]
curve_max = int(args.capture_curve_max_pct)
        curve_grid = list(range(1, curve_max + 1)) if curve_max > 0 else []
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for h_i, tau in enumerate(args.eval_horizons):
for j, cause_id in enumerate(top_cause_ids.tolist()):
p = cause_cif_focus[:, j, h_i]
y = y_within_focus[:, j, h_i]
if sex_mask is not None:
p = p[sex_mask]
y = y[sex_mask]
for r in compute_capture_points(p, y, k_pcts):
cap_points_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"horizon": float(tau),
"sex": sex_label,
**r,
}
)
if curve_grid:
for r in compute_capture_points(p, y, curve_grid):
cap_curve_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"horizon": float(tau),
"sex": sex_label,
**r,
}
)
        # Record focus-cause support counts (cases, controls, and total within
        # tau_max) in the per-model results CSV as metric rows.
for tc in top_causes_meta:
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_case_within_tau",
"horizon": float(tc["tau_years"]),
"cause": int(tc["cause_id"]),
"value": int(tc["n_case_within_tau"]),
"ci_low": "",
"ci_high": "",
}
)
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_control_within_tau",
"horizon": float(tc["tau_years"]),
"cause": int(tc["cause_id"]),
"value": int(tc["n_control_within_tau"]),
"ci_low": "",
"ci_high": "",
}
)
model_rows.append(
{
"model_name": spec.name,
"metric_name": "topcause_n_total_eval",
"horizon": float(tc["tau_years"]),
"cause": int(tc["cause_id"]),
"value": int(tc["n_total_eval"]),
"ci_low": "",
"ci_high": "",
}
)
# Write per-model results into the model's run directory.
model_calib_rows = calib_rows[calib_start:]
model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")
write_results_csv(model_out_csv, model_rows)
write_simple_csv(
model_summary_csv,
["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
model_summary_rows,
)
write_calibration_bins_csv(model_calib_csv, model_calib_rows)
model_meta = {
"model_name": spec.name,
"checkpoint_path": spec.checkpoint_path,
"run_dir": run_dir,
"split": args.split,
"offset_years": args.offset_years,
"eval_horizons": [float(x) for x in args.eval_horizons],
"n_disease": int(dataset.n_disease),
"top_cause_ids": top_cause_ids.tolist(),
"top_causes": top_causes_meta,
"integrity": {spec.name: integrity_meta.get(spec.name, {})},
"paths": {
"results_csv": model_out_csv,
"summary_csv": model_summary_csv,
"calibration_bins_csv": model_calib_csv,
},
}
with open(model_meta_json, "w") as f:
json.dump(model_meta, f, indent=2)
print(f"Wrote per-model results to {model_out_csv}")
# Write global summary (across diseases) across all models.
write_simple_csv(
args.out_csv,
["model_name", "metric_name", "horizon", "mean", "median", "n_valid"],
summary_rows,
)
# Write calibration curve points to a separate CSV.
out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
write_calibration_bins_csv(calib_csv_path, calib_rows)
# Write experiment exports
write_simple_csv(
os.path.join(export_dir, "lift_capture_points.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"horizon",
"sex",
"k_pct",
"n_targeted",
"events_targeted",
"events_total",
"event_capture_rate",
"precision_in_targeted",
],
cap_points_rows,
)
if cap_curve_rows:
write_simple_csv(
os.path.join(export_dir, "lift_capture_curve.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"horizon",
"sex",
"k_pct",
"n_targeted",
"events_targeted",
"events_total",
"event_capture_rate",
"precision_in_targeted",
],
cap_curve_rows,
)
write_simple_csv(
os.path.join(export_dir, "high_risk_cause_concentration.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"horizon",
"sex",
"topk",
"event_rate_mean",
"event_rate_median",
"recall_mean",
"recall_median",
"n_total",
"n_valid_recall",
],
conc_rows,
)
if conc_base_rows:
write_simple_csv(
os.path.join(
export_dir, "high_risk_cause_concentration_random_baseline.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"horizon",
"sex",
"topk",
"baseline_event_rate_mean",
"baseline_event_rate_p05",
"baseline_event_rate_p95",
"baseline_recall_mean",
"baseline_recall_p05",
"baseline_recall_p95",
"n_total",
"n_valid_recall",
"k_total",
"baseline_method",
],
conc_base_rows,
)
# Manifest markdown (stable, user-facing)
manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
with open(manifest_path, "w", encoding="utf-8") as f:
f.write(
"# Evaluation Exports Manifest\n\n"
"This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
"All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
"## Files\n\n"
"- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
"- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
"- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
"- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart vs K.\n"
"- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n"
)
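    # Run-level metadata: evaluation setup, focus-cause support, per-model CIF
    # integrity status, and labeling/exclusion notes.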
meta = {
"split": args.split,
"offset_years": args.offset_years,
"eval_horizons": [float(x) for x in args.eval_horizons],
"tau_max": float(tau_max),
"n_disease": int(dataset_for_top.n_disease),
"top_cause_ids": top_cause_ids.tolist(),
"top_causes": top_causes_meta,
"integrity": integrity_meta,
"notes": {
"label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
"primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
"secondary_metrics": "cause_auc (discrimination)",
"exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
"warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
"exports_dir": export_dir,
"focus_causes": focus_causes,
},
}
with open(args.out_meta_json, "w") as f:
json.dump(meta, f, indent=2)
print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
print(f"Wrote {args.out_meta_json}")
return 0
if __name__ == "__main__":
raise SystemExit(main())
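# Example invocation (illustrative sketch only; long-option spellings for
# arguments defined earlier in this file are assumed to mirror their argparse
# attribute names, e.g. args.models_json -> --models_json):
#
#   python evaluate_models.py \
#       --models_json models.json \
#       --split test \
#       --export_dir eval_exports \
#       --capture_k_pcts 1 5 10 20 \
#       --capture_curve_max_pct 50 \
#       --cause_concentration_topk 5 10 20 50 \
#       --cause_concentration_write_random_baseline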