Files
DeepHealth/evaluate_models.py

2056 lines
73 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import csv
import json
import math
import os
import random
import statistics
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi, SimpleHead
# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
# Discrete-time bin edges (years); the trailing +inf edge closes the last bin.
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, math.inf]
# Horizons (years) at which CIFs are reported; each is a finite edge above.
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
# Divisor used to turn day-valued time offsets into years.
DAYS_PER_YEAR = 365.25
# Default cause id treated as death (e.g. pick_focus_causes' default).
DEFAULT_DEATH_CAUSE_ID = 1256
# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one model entry to evaluate.

    ``load_models_json`` builds one instance per element of the models JSON.
    """

    # Label used to identify the model.
    name: str
    # Backbone family: "delphi_fork" | "sap_delphi".
    model_type: str
    # Output parameterisation: "exponential" | "discrete_time_cif".
    loss_type: str
    # Parsed from the JSON "full_cov" field (presumably: full covariate set).
    full_cov: bool
    # Checkpoint file; train_config.json is read from the same directory.
    checkpoint_path: str
# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
    """Seed all RNGs used by evaluation and pin cuDNN to deterministic kernels.

    Args:
        seed: value given to Python's ``random``, NumPy's global RNG, and
            torch (CPU generator plus every CUDA device).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Reproducible kernel selection instead of cuDNN autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
    """Interpret *x* as a boolean.

    Real ``bool`` values are returned unchanged. Anything else is converted
    with ``str``, stripped and lower-cased, then matched against a small
    yes/no vocabulary ("true"/"1"/"yes"/"y" and "false"/"0"/"no"/"n").

    Raises:
        ValueError: if the text is not one of the recognised spellings.
    """
    if isinstance(x, bool):
        return x
    token = str(x).strip().lower()
    vocabulary = {
        "true": True, "1": True, "yes": True, "y": True,
        "false": False, "0": False, "no": False, "n": False,
    }
    try:
        return vocabulary[token]
    except KeyError:
        raise ValueError(f"Cannot parse boolean: {x!r}") from None
def load_models_json(path: str) -> List[ModelSpec]:
    """Load the list of models to evaluate from a JSON file.

    The file must hold a JSON array. Each element needs the keys ``name``,
    ``model_type``, ``loss_type``, ``full_cov`` and ``checkpoint_path``.
    ``full_cov`` accepts booleans or the spellings understood by
    ``_parse_bool``.

    Args:
        path: path to the models JSON file.

    Returns:
        One ``ModelSpec`` per array element, in file order.

    Raises:
        ValueError: if the top-level value is not a list, or ``full_cov``
            cannot be parsed.
        KeyError: if an entry lacks a required key.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding
    # (breaks non-ASCII names on e.g. Windows). utf-8-sig also tolerates a BOM,
    # which json.load would otherwise reject.
    with open(path, "r", encoding="utf-8-sig") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")
    return [
        ModelSpec(
            name=str(row["name"]),
            model_type=str(row["model_type"]),
            loss_type=str(row["loss_type"]),
            full_cov=_parse_bool(row["full_cov"]),
            checkpoint_path=str(row["checkpoint_path"]),
        )
        for row in data
    ]
def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    """Load ``train_config.json`` from the directory containing a checkpoint.

    Args:
        checkpoint_path: path to a checkpoint file; relative paths are
            resolved against the current working directory.

    Returns:
        The parsed training configuration.

    Raises:
        FileNotFoundError: if no ``train_config.json`` sits next to the
            checkpoint.
    """
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    # Explicit UTF-8 (BOM tolerated) instead of the locale default encoding.
    with open(cfg_path, "r", encoding="utf-8-sig") as f:
        return json.load(f)
def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    """Return the requested split of *dataset*, reproducing a seeded split.

    The train/val/test partition comes from ``random_split`` driven by a
    ``torch.Generator`` seeded with *seed*. The same dataset length, ratios
    and seed therefore give the same subsets. Sizes are
    ``int(n * train_ratio)`` and ``int(n * val_ratio)``, with the remainder
    going to test.

    Args:
        dataset: full dataset to split.
        train_ratio: fraction of samples assigned to the training split.
        val_ratio: fraction of samples assigned to the validation split.
        seed: seed for the permutation generator.
        split: one of ``"train"``, ``"val"``, ``"test"`` or ``"all"``.

    Returns:
        The chosen subset, or *dataset* itself for ``"all"``.

    Raises:
        ValueError: for an unknown *split*, or ratios that would give any
            split a negative size.
    """
    # Validate first so a typo fails fast, before any partitioning work.
    if split not in ("train", "val", "test", "all"):
        raise ValueError("split must be one of: train, val, test, all")
    if split == "all":
        # No partition needed, so the ratios are irrelevant here.
        return dataset
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    # random_split only checks that the lengths sum to n_total. A negative
    # length (ratios summing to > 1) would silently produce wrong subsets.
    if min(n_train, n_val, n_test) < 0:
        raise ValueError(
            f"invalid split ratios (train_ratio={train_ratio}, val_ratio={val_ratio}) "
            f"for {n_total} samples: sizes would be {n_train}/{n_val}/{n_test}"
        )
    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    return {"train": train_ds, "val": val_ds, "test": test_ds}[split]
# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pick, for each sample, the token used as the prediction context.

    Semantics (important for avoiding leakage):
      - The time of the last non-padding token is the FOLLOW-UP END. It is
        not treated as an event time.
      - The context is the last non-padding token whose time is at or before
        ``followup_end - offset_years`` (the offset is converted to days).

    The index arithmetic (last index = count - 1) assumes right-padded rows
    (padding id 0 after all real tokens) with ascending times. Under that
    assumption the eligible tokens form a prefix.

    Returns:
        keep_mask: (B,) bool, True where at least one token is eligible.
        t_ctx: (B,) long, context index (0 where keep_mask is False).
        t_ctx_time: (B,) time in days at t_ctx.
    """
    rows = torch.arange(event_seq.size(0), device=event_seq.device)
    is_token = event_seq.ne(0)
    # Last real token per row; empty rows are clamped to read index 0.
    final_pos = (is_token.sum(dim=1) - 1).clamp(min=0)
    cutoff = time_seq[rows, final_pos] - offset_years * DAYS_PER_YEAR
    n_eligible = (is_token & (time_seq <= cutoff.unsqueeze(1))).sum(dim=1)
    ctx_idx = (n_eligible - 1).clamp(min=0).to(torch.long)
    return n_eligible > 0, ctx_idx, time_seq[rows, ctx_idx]
def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the first disease token positioned after each context index.

    Disease tokens are ids >= 2; padding (0) and the other low ids never
    qualify. A same-day event counts as long as its index is past ``t_ctx``.
    Rows are assumed time-sorted, so "first by index" means "next in time".

    Returns:
        dt_years: (B,) float, years from the context time to that event;
            +inf if there is none.
        cause: (B,) long, disease id (token - 2) of that event; -1 if none.
    """
    n_rows, seq_len = event_seq.shape
    device = event_seq.device
    rows = torch.arange(n_rows, device=device)
    positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(n_rows, -1)
    # ids >= 2 already excludes padding (0).
    candidate = (positions > t_ctx.unsqueeze(1)) & (event_seq >= 2)
    sentinel = torch.full_like(positions, seq_len)
    first_pos = torch.where(candidate, positions, sentinel).min(dim=1).values
    found = first_pos < seq_len
    safe_pos = torch.where(found, first_pos, torch.zeros_like(first_pos))

    gap_years = (time_seq[rows, safe_pos] - time_seq[rows, t_ctx]) / DAYS_PER_YEAR
    gap_years = torch.where(found, gap_years, torch.full_like(gap_years, float("inf")))

    disease = (event_seq[rows, safe_pos] - 2).to(torch.long)
    disease = torch.where(found, disease, torch.full_like(disease, -1))
    return gap_years, disease
def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Multi-hot labels: does disease k occur within ``tau_years`` after the context?

    A token counts when its index is strictly after ``t_ctx``, its time lies
    in [context time, context time + tau], and it is a disease token
    (id >= 2). Same-day events qualify.

    Returns:
        (B, n_disease) bool tensor.
    """
    n_rows, seq_len = event_seq.shape
    device = event_seq.device
    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=device)
    start = time_seq[torch.arange(n_rows, device=device), t_ctx].unsqueeze(1)
    stop = start + tau_years * DAYS_PER_YEAR
    after_ctx = torch.arange(seq_len, device=device).unsqueeze(0) > t_ctx.unsqueeze(1)
    hits = after_ctx & (time_seq >= start) & (time_seq <= stop) & (event_seq >= 2)
    if hits.any():
        row_idx, col_idx = hits.nonzero(as_tuple=True)
        labels[row_idx, (event_seq[row_idx, col_idx] - 2).to(torch.long)] = True
    return labels
def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Multi-hot labels: does disease k occur ANYWHERE after the prediction context?

    Delphi2M-compatible case/control definition for Task A. Every disease
    token (id >= 2) whose index is after the context index counts, including
    same-day events. No horizon limit is applied.

    Returns:
        (B, n_disease) bool tensor.
    """
    n_rows, seq_len = event_seq.shape
    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=event_seq.device)
    positions = torch.arange(seq_len, device=event_seq.device).unsqueeze(0)
    later_disease = (positions > t_ctx.unsqueeze(1)) & (event_seq >= 2)
    if not later_disease.any():
        return labels
    row_idx, col_idx = later_disease.nonzero(as_tuple=True)
    labels[row_idx, (event_seq[row_idx, col_idx] - 2).to(torch.long)] = True
    return labels
def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Within-horizon labels restricted to a chosen subset of causes.

    Column j is True when disease ``cause_ids[j]`` occurs at an index after
    ``t_ctx`` and within tau_years of the context time (same-day events
    count). Occurrences of diseases outside *cause_ids* are ignored.

    Args:
        cause_ids: global disease ids to report, each in [0, n_disease).
        n_disease: size of the global disease space (sizes the lookup table).

    Returns:
        (B, cause_ids.numel()) bool tensor.
    """
    n_rows, seq_len = event_seq.shape
    device = event_seq.device
    labels = torch.zeros((n_rows, cause_ids.numel()), dtype=torch.bool, device=device)
    start = time_seq[torch.arange(n_rows, device=device), t_ctx].unsqueeze(1)
    stop = start + tau_years * DAYS_PER_YEAR
    after_ctx = torch.arange(seq_len, device=device).unsqueeze(0) > t_ctx.unsqueeze(1)
    hits = after_ctx & (time_seq >= start) & (time_seq <= stop) & (event_seq >= 2)
    if not hits.any():
        return labels
    row_idx, col_idx = hits.nonzero(as_tuple=True)
    disease = (event_seq[row_idx, col_idx] - 2).to(torch.long)
    # Global disease id -> output column, or -1 for causes not selected.
    column_of = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
    column_of[cause_ids] = torch.arange(cause_ids.numel(), device=device)
    column = column_of[disease]
    wanted = column >= 0
    if not wanted.any():
        return labels
    labels[row_idx[wanted], column[wanted]] = True
    return labels
# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential (constant-hazard) cause-specific logits into CIFs.

    Each cause k gets hazard ``lambda_k = softplus(logit_k) + eps``. With
    total hazard ``Lambda = sum_k lambda_k``:
        CIF_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau))
        S(tau)     = exp(-Lambda * tau)

    Args:
        logits: (B, K) raw head outputs.
        taus: H horizons, in the time unit the hazards were trained in
            (presumably years; confirm against training).
        eps: floor added to every hazard.
        return_survival: also return S(tau).

    Returns:
        cif of shape (B, K, H); or the tuple ``(cif, survival)``, with
        survival of shape (B, H), when *return_survival* is True.
    """
    hazards = F.softplus(logits) + eps
    total = hazards.sum(dim=1, keepdim=True)  # (B, 1)
    taus_t = torch.tensor(list(taus), device=logits.device,
                          dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B, 1, 1)
    exponent = -total_h * taus_t  # (B, 1, H)
    # 1 - exp(x) is computed as -expm1(x). When Lambda * tau is small (low
    # risk, short horizon), float32 `1 - exp(x)` cancels catastrophically and
    # loses most significant digits.
    one_minus_surv = -torch.expm1(exponent)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B, K, H)
    # Zero-total-hazard guard. Only reachable when eps <= 0 (with underflowing
    # softplus) or K == 0.
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))
    if not return_survival:
        return cif
    survival = torch.exp(exponent).squeeze(1)  # (B, H)
    # (B,) mask -> (B, 1) so it broadcasts against (B, H).
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival
def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert discrete-time competing-risk logits into CIFs at given horizons.

    A softmax over dim 1 turns each bin slot into probabilities for the K
    causes plus a final "no event in this bin" class. Slot 0 and any slot
    beyond the last finite edge are ignored. Slots ``1..n_finite`` map to
    the finite edges in ``bin_edges[1:]``.

    With ``h[k, u]`` the cause probabilities and ``c[u]`` the no-event
    probability of used bin u:
        S_before(u)   = prod_{v<u} c[v]
        CIF_k(edge_j) = sum_{u<=j} S_before(u) * h[k, u]
        S(edge_j)     = prod_{u<=j} c[u]

    Args:
        logits: (B, K+1, len(bin_edges)).
        bin_edges: edges starting at 0, optionally ending in +inf.
        taus: H finite horizons, each within 1e-6 of a finite edge.
        return_survival: also return S at each tau.

    Returns:
        cif (B, K, H); or ``(cif, survival)`` with survival (B, H).

    Raises:
        ValueError: for non-3-D logits, a last dim different from
            ``len(bin_edges)``, or a tau that is non-finite or not on an edge.
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")
    batch, n_classes, n_slots = logits.shape
    n_causes = n_classes - 1
    edge_values = [float(e) for e in bin_edges]
    usable_edges = [e for e in edge_values[1:] if math.isfinite(e)]
    n_used = len(usable_edges)
    if n_slots != len(edge_values):
        raise ValueError("logits last dim must match len(bin_edges)")

    probs = torch.softmax(logits, dim=1)
    window = slice(1, 1 + n_used)
    cause_p = probs[:, :n_causes, window]  # (B, K, n_used)
    stay_p = probs[:, n_causes, window]  # (B, n_used)
    surv_after = torch.cumprod(stay_p, dim=1)  # survival at the end of each bin
    leading_one = torch.ones((batch, 1), device=logits.device, dtype=probs.dtype)
    surv_before = torch.cat([leading_one, surv_after[:, :-1]], dim=1)
    cif_at_edges = torch.cumsum(surv_before.unsqueeze(1) * cause_p, dim=2)

    # Map each tau to its nearest finite edge, tolerating serialisation noise.
    edge_arr = np.asarray(usable_edges, dtype=float)
    positions: List[int] = []
    for tau in taus:
        t = float(tau)
        if not math.isfinite(t):
            raise ValueError("taus must be finite for discrete-time CIF")
        gaps = np.abs(edge_arr - t)
        nearest = int(np.argmin(gaps))
        if gaps[nearest] > 1e-6:
            raise ValueError(
                f"tau={t} not close to any finite bin edge (min |edge-tau|={gaps[nearest]})"
            )
        positions.append(nearest)
    index = torch.tensor(positions, device=logits.device, dtype=torch.long)
    cif = cif_at_edges.index_select(dim=2, index=index)
    if return_survival:
        return cif, surv_after.index_select(dim=1, index=index)
    return cif
# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on a CIF array.

    Checks: array rank, horizon count, finiteness, range [0, 1],
    monotonicity along horizons, total mass ``sum_k CIF <= 1`` and, when
    *survival* is given, conservation ``sum_k CIF == 1 - survival``.
    Zero-size arrays have nothing that can violate these checks, so the
    reductions are skipped for them.

    Args:
        cause_cif: (N, K, H) cumulative incidence values.
        horizons: the H horizons the last axis corresponds to.
        tol: tolerance for the inequality checks. The conservation check
            uses ``max(tol, 1e-4)``.
        name: model name prefixed to printed messages.
        strict: raise ``ValueError`` on the first failed check instead of
            printing a warning and continuing.
        survival: optional (N, H) overall survival at the same horizons.

    Returns:
        (integrity_ok, notes). integrity_ok is False if any check failed.
        notes lists the failure messages, plus a note when the conservation
        check was skipped (that note does not affect integrity_ok).

    Raises:
        ValueError: only when *strict* is True and a check fails.
    """
    notes: List[str] = []
    model_tag = f"[{name}] " if name else ""
    skip_note = "integrity: survival not provided; skipping conservation check"

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    N, K, H = cif.shape
    if H != len(horizons):
        _fail(
            f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")

    # (5) Finite
    if not np.isfinite(cif).all():
        _fail("integrity: non-finite values (NaN/Inf) in cause_cif")

    # nanmin/nanmax raise "zero-size array to reduction operation" on empty
    # input (N, K or H == 0), so every reduction below is size-guarded.
    # (1) Range
    if cif.size > 0:
        cmin = float(np.nanmin(cif))
        cmax = float(np.nanmax(cif))
        if cmin < -tol:
            _fail(f"integrity: range min={cmin} < -tol={-tol}")
        if cmax > 1.0 + tol:
            _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")

    # (2) Monotonicity in horizons (per n, k)
    diffs = np.diff(cif, axis=2)
    if diffs.size > 0 and np.nanmin(diffs) < -tol:
        _fail("integrity: monotonicity violated (found negative diff along horizons)")

    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N, H)
    if mass.size > 0:
        mass_max = float(np.nanmax(mass))
        if mass_max > 1.0 + tol:
            _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")

    # (4) Conservation with survival, if provided
    if survival is None:
        # Informational only: recorded but never a failure, even when strict.
        if not strict:
            print("WARNING:", model_tag + skip_note)
        notes.append(skip_note)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(
                f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        else:
            err = np.abs((1.0 - s) - mass)
            # Discrete-time should be very tight; exponential may accumulate
            # slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if err.size > 0:
                max_err = float(np.nanmax(err))
                if max_err > tol_cons:
                    _fail(
                        f"integrity: conservation violated (max |(1-surv)-sum_cif|={max_err}, tol={tol_cons})")

    ok = all(note.endswith("skipping conservation check") for note in notes)
    return ok, notes
# ============================================================
# Metrics
# ============================================================
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of *x*; tied values share the mean of their ranks.

    NaNs are sorted last (NumPy convention) and treated as tied with each
    other. Previously a NaN never compared equal to itself, so the tie-scan
    never advanced and the function looped forever.

    Example: ``[10, 20, 20, 30] -> [1.0, 2.5, 2.5, 4.0]``.
    """
    values = np.asarray(x, dtype=float)
    order = np.argsort(values)
    ordered = values[order]
    nan_sorted = np.isnan(ordered)
    size = values.shape[0]
    ranks_sorted = np.zeros(size, dtype=float)
    i = 0
    while i < size:
        # Start past i so the scan always advances, even for NaN.
        j = i + 1
        while j < size and (ordered[j] == ordered[i] or (nan_sorted[j] and nan_sorted[i])):
            j += 1
        # Sorted positions i..j-1 share the average 1-based rank.
        ranks_sorted[i:j] = 0.5 * (i + j - 1) + 1.0
        i = j
    ranks = np.empty(size, dtype=float)
    ranks[order] = ranks_sorted
    return ranks
def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong estimate of AUCs and their covariance.

    Args:
        predictions_sorted_transposed: (n_classifiers, n_examples) scores,
            with all positive examples in the first ``label_1_count`` columns.
        label_1_count: number of positive examples m.

    Returns:
        (aucs, cov). aucs has shape (n_classifiers,) and cov has shape
        (n_classifiers, n_classifiers). Returns ``([nan], [[nan]])`` if
        either class is empty.
    """
    scores = np.asarray(predictions_sorted_transposed, dtype=float)
    m = int(label_1_count)
    n = int(scores.shape[1] - m)
    if m <= 0 or n <= 0:
        return np.array([float("nan")]), np.array([[float("nan")]])
    # Midranks within positives (tx), within negatives (ty) and overall (tz).
    tx = np.array([compute_midrank(row) for row in scores[:, :m]])
    ty = np.array([compute_midrank(row) for row in scores[:, m:]])
    tz = np.array([compute_midrank(row) for row in scores])
    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
    # Structural components: v01 per positive example, v10 per negative.
    v01 = (tz[:, :m] - tx) / n
    v10 = 1.0 - (tz[:, m:] - ty) / m
    if v01.shape[0] > 1:
        sx, sy = np.cov(v01), np.cov(v10)
    else:
        # np.cov of a single row returns a 0-d array; build explicit 1x1 matrices.
        sx = np.array([[float(np.var(v01, axis=1, ddof=1)[0])]])
        sy = np.array([[float(np.var(v10, axis=1, ddof=1)[0])]])
    return aucs, sx / m + sy / n
def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    """Return ``(auc, variance)`` for a single classifier via fast DeLong.

    Labels are cast to int; 1 is positive and 0 negative. Returns
    ``(nan, nan)`` when either class is absent.

    Raises:
        ValueError: if the inputs are not 1-D arrays of equal length.
    """
    labels = np.asarray(ground_truth, dtype=int)
    scores = np.asarray(predictions, dtype=float)
    same_shape_1d = labels.ndim == 1 and scores.ndim == 1 and labels.shape[0] == scores.shape[0]
    if not same_shape_1d:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")
    n_pos = int(np.sum(labels == 1))
    n_neg = int(np.sum(labels == 0))
    if n_pos == 0 or n_neg == 0:
        return float("nan"), float("nan")
    # fastDeLong expects positives in the leading columns.
    positives_first = scores[np.argsort(-labels)]
    aucs, cov = fastDeLong(positives_first[np.newaxis, :], n_pos)
    return float(aucs[0]), float(cov[0, 0])
def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return ``(auc, ci_low, ci_high)`` from the DeLong variance.

    Uses a two-sided normal interval at confidence level *alpha*, clipped to
    [0, 1]. If the variance is NaN or non-positive, a warning is printed and
    both bounds are NaN.
    """
    auc, var = calc_auc_variance(ground_truth, predictions)
    if np.isfinite(var) and var > 0:
        z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0)
        half_width = z * math.sqrt(var)
        return float(auc), float(max(0.0, auc - half_width)), float(min(1.0, auc + half_width))
    print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
    return float(auc), float("nan"), float("nan")
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """ROC AUC from ranks (Mann-Whitney U statistic); ties get midranks.

    Returns NaN when *y_true* has no positives (1) or no negatives (0).

    Raises:
        ValueError: if the inputs are not 1-D arrays of equal length.
    """
    labels = np.asarray(y_true, dtype=int)
    scores = np.asarray(y_score, dtype=float)
    if not (labels.ndim == 1 and scores.ndim == 1 and labels.shape[0] == scores.shape[0]):
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")
    is_pos = labels == 1
    n_pos = int(np.sum(is_pos))
    n_neg = int(np.sum(labels == 0))
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # U = (rank sum of positives) - n_pos(n_pos+1)/2, normalised by n_pos*n_neg.
    rank_sum_pos = float(np.sum(compute_midrank(scores)[is_pos]))
    return float((rank_sum_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg))
def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Percentile bootstrap confidence interval for ROC AUC.

    Resamples with replacement *n_bootstrap* times using a seeded NumPy
    Generator. Single-class resamples are discarded.

    Returns:
        (auc_on_full_data, lo, hi). All three are NaN for degenerate labels.
        lo/hi are NaN when fewer than 10 resamples produced a valid AUC.
    """
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    nan = float("nan")
    degenerate_msg = "WARNING: bootstrap AUC CI degenerate labels; CI set to NaN"
    if n == 0 or np.all(labels == labels[0]):
        print(degenerate_msg)
        return nan, nan, nan
    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print(degenerate_msg)
        return nan, nan, nan
    resampled: List[float] = []
    for _ in range(int(n_bootstrap)):
        idx = rng.integers(0, n, size=n)
        y_b = labels[idx]
        # A single-class resample has no AUC; its draw is still consumed.
        if np.all(y_b == y_b[0]):
            continue
        auc_b = roc_auc_rank(y_b, scores[idx])
        if np.isfinite(auc_b):
            resampled.append(float(auc_b))
    if len(resampled) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), nan, nan
    lo_q = (1.0 - float(alpha)) / 2.0
    lo = float(np.quantile(resampled, lo_q))
    hi = float(np.quantile(resampled, 1.0 - lo_q))
    return float(auc_full), lo, hi
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    """Mean squared difference between predicted probabilities and outcomes."""
    residual = np.asarray(p, dtype=float) - np.asarray(y, dtype=float)
    return float(np.mean(np.square(residual)))
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    """Quantile-binned calibration table.

    Predictions are split at the empirical quantiles of *p*, with the outer
    edges widened to -inf/+inf. Empty bins (possible with tied predictions)
    are omitted. ``ici`` here is the unweighted mean, over non-empty bins,
    of |mean prediction - observed rate|.

    Returns:
        ``{"bins": [{"bin", "p_mean", "y_mean", "n"}, ...], "ici": float}``,
        or ``{"bins": [], "ici": nan}`` for empty input.
    """
    probs = np.asarray(p, dtype=float)
    outcomes = np.asarray(y, dtype=float)
    if probs.size == 0:
        return {"bins": [], "ici": float("nan")}
    cuts = np.quantile(probs, np.linspace(0.0, 1.0, n_bins + 1))
    # Open the outer edges so the extreme predictions fall inside a bin.
    cuts[0] = -np.inf
    cuts[-1] = np.inf
    rows: List[Dict[str, Any]] = []
    for i, (lo, hi) in enumerate(zip(cuts[:-1], cuts[1:])):
        members = (probs > lo) & (probs <= hi)
        if not np.any(members):
            continue
        rows.append({
            "bin": i,
            "p_mean": float(np.mean(probs[members])),
            "y_mean": float(np.mean(outcomes[members])),
            "n": int(members.sum()),
        })
    gap_total = sum(abs(r["p_mean"] - r["y_mean"]) for r in rows)
    return {"bins": rows, "ici": float(gap_total / max(len(rows), 1))}
def _safe_float(x: Any, default: float = float("nan")) -> float:
    """Coerce *x* to float, falling back to ``float(default)`` on any failure."""
    try:
        value = float(x)
    except Exception:
        value = float(default)
    return value
def _ensure_dir(path: str) -> None:
    """Create directory *path*, including parents, if it does not exist.

    An empty path denotes the current working directory, which always
    exists, so it is a no-op. Plain ``os.makedirs("")`` raises
    FileNotFoundError, which bites callers passing ``os.path.dirname`` of a
    bare file name.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)
def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
    """Load a 0-based ``cause_id -> name`` mapping from a one-label-per-line file.

    Line i (0-based) holds the name of cause i. Blank lines keep their index
    but produce no entry. A missing file yields an empty mapping.
    """
    if not os.path.exists(path):
        return {}
    mapping: Dict[int, str] = {}
    # utf-8-sig drops a leading byte-order mark (common in spreadsheet
    # exports). str.strip() does not remove U+FEFF, so without this it would
    # stay glued to the name of cause 0.
    with open(path, "r", encoding="utf-8-sig") as f:
        for cause_id, line in enumerate(f):
            name = line.strip()
            if name:
                mapping[cause_id] = name
    return mapping
def pick_focus_causes(
    *,
    counts_within_tau: Optional[np.ndarray],
    n_disease: int,
    death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
    k: int = 5,
) -> List[int]:
    """Choose the causes to highlight in user-facing evaluation.

    Rule:
      1) ``death_cause_id`` comes first if it lies in ``[0, n_disease)``.
         Otherwise a warning is printed and it is left out.
      2) Up to *k* further causes follow. With per-disease counts of matching
         length, these are the causes with the highest positive counts,
         ties broken by the smaller id. Without counts, or on a length
         mismatch (which prints a warning), the *k* largest ids are used as a
         deterministic stand-in.

    Returns:
        Cause ids in selection order, without duplicates.
    """
    n_total = int(n_disease)
    chosen: List[int] = []
    if 0 <= death_cause_id < n_total:
        chosen.append(int(death_cause_id))
    else:
        print(
            f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_total}); "
            "it will be omitted from focus causes."
        )
    death_id = int(death_cause_id)
    others = [cid for cid in range(n_total) if cid != death_id]

    ranked: Optional[List[int]] = None
    if counts_within_tau is not None:
        counts = np.asarray(counts_within_tau).astype(float)
        if counts.shape[0] == n_total:
            by_count = sorted(others, key=lambda cid: (-float(counts[cid]), int(cid)))
            ranked = [cid for cid in by_count if float(counts[cid]) > 0]
        else:
            print(
                "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering."
            )
    if ranked is None:
        # Coverage proxy without data: descending cause id (others is ascending).
        ranked = others[::-1]
    chosen.extend(int(cid) for cid in ranked[: int(k)])
    # Drop duplicates, keeping the first occurrence.
    return list(dict.fromkeys(chosen))
def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
    """Write *rows* to *path* as UTF-8 CSV with a header line.

    The parent directory is created first if needed.

    Raises:
        ValueError: from ``csv.DictWriter`` if a row has a key not listed in
            *fieldnames*.
    """
    parent = os.path.dirname(os.path.abspath(path)) or "."
    _ensure_dir(parent)
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
    """Build (label, mask) pairs for subgroup reporting.

    Always starts with ``("all", None)``. For a 1-D *sex* array, adds
    ``("0", mask)`` and/or ``("1", mask)`` for each code present at least
    once. Missing or non-1-D input yields only the "all" slice.
    """
    slices: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
    codes = None if sex is None else np.asarray(sex)
    if codes is None or codes.ndim != 1:
        return slices
    for code in (0, 1):
        mask = codes == code
        if int(np.sum(mask)) > 0:
            slices.append((str(code), mask))
    return slices
def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
    """Return the q+1 empirical quantile cut points of *p*, outer ones at -inf/+inf."""
    levels = np.linspace(0.0, 1.0, int(q) + 1)
    cut_points = np.asarray(np.quantile(p, levels), dtype=float)
    # Open the outer bins so every value, including the extremes, lands in one.
    cut_points[0] = -np.inf
    cut_points[-1] = np.inf
    return cut_points
def compute_risk_stratification_bins(
    p: np.ndarray,
    y: np.ndarray,
    *,
    q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
    """Quantile-based risk strata plus a compact lift/slope summary.

    Args:
        p: predicted risks, one per individual.
        y: observed outcomes aligned with *p*.
        q_default: number of strata; forced to 5 when there are fewer than
            200 samples.

    Returns:
        (q, bin_rows, summary):
          q: number of strata actually used (0 for empty input).
          bin_rows: one dict per stratum (1-based ``q`` key). Empty strata
            keep a row with NaN statistics for consistent plotting.
          summary: overall event rate; event rate in the top 10% and the
            bottom 50% of predicted risk; their ratio; and the size-weighted
            least-squares slope of observed rate on mean prediction across
            non-empty strata.
    """
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return 0, [], {
            "y_overall": float("nan"),
            "top_decile_y_rate": float("nan"),
            "bottom_half_y_rate": float("nan"),
            "lift_top10_vs_bottom50": float("nan"),
            "slope_pred_vs_obs": float("nan"),
        }
    # Quintiles instead for small samples (presumably to keep strata populated).
    q = int(q_default)
    if n < 200:
        q = 5
    edges = _quantile_edges(p, q)
    y_overall = float(np.mean(y))
    bin_rows: List[Dict[str, Any]] = []
    p_means: List[float] = []
    y_rates: List[float] = []
    n_bins: List[int] = []
    for i in range(q):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        nb = int(np.sum(mask))
        if nb == 0:
            # Keep the row for consistent plotting; statistics are undefined.
            bin_rows.append(
                {
                    "q": int(i + 1),
                    "n_bin": 0,
                    "p_mean": float("nan"),
                    "y_rate": float("nan"),
                    "y_overall": y_overall,
                    "lift_vs_overall": float("nan"),
                }
            )
            continue
        p_mean = float(np.mean(p[mask]))
        y_rate = float(np.mean(y[mask]))
        lift = float(y_rate / y_overall) if y_overall > 0 else float("nan")
        bin_rows.append(
            {
                "q": int(i + 1),
                "n_bin": nb,
                "p_mean": p_mean,
                "y_rate": y_rate,
                "y_overall": y_overall,
                "lift_vs_overall": lift,
            }
        )
        p_means.append(p_mean)
        y_rates.append(y_rate)
        n_bins.append(nb)

    # Summary. The keys promise top 10% vs bottom 50% of predicted risk, so
    # take both groups from decile cut points whatever q is. Indexing the
    # stratum edges (edges[q - 1], edges[q // 2]) measured top 20% vs bottom
    # 40% under the q=5 small-sample fallback. For q == 10 nothing changes.
    deciles = edges if q == 10 else _quantile_edges(p, 10)
    top_mask = (p > deciles[9]) & (p <= deciles[10])
    bot_half_mask = (p > deciles[0]) & (p <= deciles[5])
    top_y = float(np.mean(y[top_mask])) if int(
        np.sum(top_mask)) > 0 else float("nan")
    bot_y = float(np.mean(y[bot_half_mask])) if int(
        np.sum(bot_half_mask)) > 0 else float("nan")
    lift_top_vs_bottom = float(top_y / bot_y) if (np.isfinite(top_y)
                                                 and np.isfinite(bot_y) and bot_y > 0) else float("nan")

    slope = float("nan")
    if len(p_means) >= 2:
        # Weighted least-squares slope of y_rate ~ p_mean (weights = stratum sizes).
        x = np.asarray(p_means, dtype=float)
        yy = np.asarray(y_rates, dtype=float)
        w = np.asarray(n_bins, dtype=float)
        xm = float(np.average(x, weights=w))
        ym = float(np.average(yy, weights=w))
        denom = float(np.sum(w * (x - xm) ** 2))
        if denom > 0:
            slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom)

    summary = {
        "y_overall": y_overall,
        "top_decile_y_rate": top_y,
        "bottom_half_y_rate": bot_y,
        "lift_top10_vs_bottom50": lift_top_vs_bottom,
        "slope_pred_vs_obs": slope,
    }
    return q, bin_rows, summary
def compute_capture_points(
    p: np.ndarray,
    y: np.ndarray,
    k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
    """Event capture when targeting the top k% of predicted risk.

    For each k in *k_pcts*, the ``ceil(n * k / 100)`` highest-risk
    individuals are targeted, clamped to between 1 and n. Reports the events
    in that group, the share of all events captured (NaN if there are no
    events) and the event rate within the group.

    Returns:
        One dict per k, or ``[]`` for empty input.
    """
    scores = np.asarray(p, dtype=float)
    outcomes = np.asarray(y, dtype=float)
    n = int(scores.shape[0])
    if n == 0:
        return []
    ranked_outcomes = outcomes[np.argsort(-scores)]
    events_total = float(np.sum(ranked_outcomes))

    def _row(k: int) -> Dict[str, Any]:
        n_targeted = max(1, min(int(math.ceil(n * float(k) / 100.0)), n))
        hit = float(np.sum(ranked_outcomes[:n_targeted]))
        return {
            "k_pct": int(k),
            "n_targeted": int(n_targeted),
            "events_targeted": float(hit),
            "events_total": float(events_total),
            "event_capture_rate": float(hit / events_total) if events_total > 0 else float("nan"),
            "precision_in_targeted": float(hit / float(n_targeted)),
        }

    return [_row(k) for k in k_pcts]
def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]:
    """Assign each unique horizon to a short/medium/long group by rank.

    Unique horizons are sorted ascending. The first four are "short"
    (rank 1), the next four "medium" (rank 2) and the rest "long" (rank 3).

    Returns:
        (rows, mapping, method): one row per unique horizon, a
        horizon -> group mapping, and a string naming this grouping rule.
    """
    tiers = (("short", 1), ("medium", 2), ("long", 3))
    rows: List[Dict[str, Any]] = []
    mapping: Dict[float, str] = {}
    for position, horizon in enumerate(sorted({float(h) for h in horizons})):
        label, rank = tiers[min(position // 4, 2)]
        mapping[horizon] = label
        rows.append({"horizon": horizon, "group": label, "group_rank": rank})
    return rows, mapping, "continuous_unique_horizons_first4_next4_rest"
def count_occurs_within_horizon(
    loader: DataLoader,
    offset_years: float,
    tau_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count, per disease, the evaluable individuals who develop it within tau.

    An individual is evaluable when ``select_context_indices`` finds a
    context token. Disease k is counted at most once per individual. It must
    appear strictly after the context token and no later than *tau_years*
    after the context time.

    Args:
        loader: yields 5-tuples whose first two items are the event-id and
            time sequences.
        offset_years: passed to ``select_context_indices``.
        tau_years: horizon after the context, in years.
        n_disease: number of disease classes (length of the counts).
        device: device the label computation runs on.

    Returns:
        (counts, n_total_eval). counts[k] is the number of individuals with
        disease k in the window; n_total_eval is the number of evaluable
        individuals.
    """
    counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_total_eval = 0
    for batch in loader:
        event_seq, time_seq, _cont, _cate, _sex = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        n_total_eval += int(keep.sum().item())
        # Reuse the shared label helper so "within tau" is defined in one place.
        # Summing per-person multi-hot rows gives per-disease person counts.
        # This replaces a hand-rolled `key % n_disease` dedup that silently
        # misattributed out-of-range disease ids to other persons/diseases.
        occurred = multi_hot_ever_within_horizon(
            event_seq[keep], time_seq[keep], t_ctx[keep], float(tau_years), int(n_disease)
        )
        counts += occurred.sum(dim=0)
    return counts.detach().cpu().numpy(), int(n_total_eval)
# ============================================================
# Evaluation core
# ============================================================
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    """Rebuild the backbone and output head described by a training config.

    Checkpoint weights are not loaded here; both modules are freshly
    constructed and moved to *device*.

    Args:
        cfg: training configuration (model_type, loss_type, n_embd, n_head,
            n_layer and optional keys).
        dataset: supplies n_disease, n_cont, n_cate and cate_dims.
        device: target device.
        checkpoint_path: only used in the missing-embedding warning message.

    Returns:
        (backbone, head, loss_type, bin_edges). bin_edges is
        ``cfg["bin_edges"]`` if present, else DEFAULT_BIN_EDGES, for either
        loss type.

    Raises:
        ValueError: unsupported loss_type or model_type.
        KeyError: a required config key is missing.
    """
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)

    if loss_type == "exponential":
        # One hazard logit per disease.
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        # Diseases plus a "no event" class, for every bin-edge slot.
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")

    def _backbone_kwargs() -> Dict[str, Any]:
        # Constructor arguments shared by both backbone families.
        return dict(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
        )

    if model_type == "delphi_fork":
        backbone = DelphiFork(**_backbone_kwargs()).to(device)
    elif model_type == "sap_delphi":
        # Accept both spellings: pretrained_emb_path (preferred) and the
        # legacy pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in {"", None}:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in {"", None}:
            run_dir = os.path.dirname(os.path.abspath(
                checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            **_backbone_kwargs(),
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    return backbone, head, loss_type, bin_edges
@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the model over ``loader`` and collect cause-specific, time-dependent CIFs.

    For every batch a context position per sample is chosen by
    ``select_context_indices`` (using ``offset_years``); samples it marks as
    not kept are dropped. The backbone hidden state at each sample's context
    position is passed through ``head`` and converted into CIFs at each
    horizon of ``eval_horizons`` according to ``loss_type``.

    Args:
        backbone: Sequence encoder, called as
            ``backbone(event_seq, time_seq, sexes, cont_feats, cate_feats)``.
        head: Output head applied to the context hidden state.
        loss_type: ``"exponential"`` or ``"discrete_time_cif"``; selects the
            logits -> CIF conversion.
        bin_edges: Time-bin edges, only used for ``"discrete_time_cif"``.
        loader: Yields ``(event_seq, time_seq, cont_feats, cate_feats, sexes)``.
        device: Device the batch tensors are moved to.
        offset_years: Anti-leakage offset forwarded to ``select_context_indices``.
        eval_horizons: Horizons (years) at which CIFs and labels are evaluated.
        top_cause_ids: Cause indices extracted from the full CIF.

    Returns:
        A 5-tuple of numpy arrays concatenated over all kept samples (N):
            cause_cif: (N, topK, H) CIFs for ``top_cause_ids``.
            cif_full: (N, K, H) CIFs for all causes returned by the CIF helper.
            survival: (N, H) survival returned by the CIF helper.
            y_cause_within_tau: (N, topK, H) multi-hot labels (cast to float32)
                from ``multi_hot_selected_causes_within_horizon``.
            sex: (N,) the ``sexes`` values of the kept samples.

    Raises:
        ValueError: If ``loss_type`` is not supported.
        RuntimeError: If every sample is filtered out by context selection.

    NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
    """
    backbone.eval()
    head.eval()

    # Accumulate per-batch results on the CPU, then concatenate once at the end.
    cause_cif_list: List[np.ndarray] = []
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []

    top_cause_ids_t = torch.tensor(
        top_cause_ids, dtype=torch.long, device=device)

    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue

        # Drop samples without a valid context position.
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep]
        t_ctx = t_ctx[keep]

        h = backbone(event_seq, time_seq, sexes_k,
                     cont_feats, cate_feats)  # (B,L,D)
        # Gather each sample's hidden state at its own context position.
        b = torch.arange(h.size(0), device=device)
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)

        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                # (B,K,H), (B,H)
                logits, bin_edges, eval_horizons, return_survival=True)
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")

        cause_cif = cif_full.index_select(
            dim=1, index=top_cause_ids_t)  # (B,topK,H)

        # Within-horizon labels for cause-specific CIF quality + discrimination,
        # stacked along a trailing horizon axis.
        n_disease = int(cif_full.size(1))
        y_within_top = torch.stack(
            [
                multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau),
                    cause_ids=top_cause_ids_t,
                    n_disease=n_disease,
                ).to(torch.float32)
                for tau in eval_horizons
            ],
            dim=2,
        )  # (B,topK,H)

        cause_cif_list.append(cause_cif.detach().cpu().numpy())
        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())

    if not cause_cif_list:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")

    cause_cif = np.concatenate(cause_cif_list, axis=0)
    cif_full = np.concatenate(cif_full_list, axis=0)
    survival = np.concatenate(survival_list, axis=0)
    y_cause_within = np.concatenate(y_cause_within_list, axis=0)
    # sex_list is appended in lockstep with cause_cif_list, so it is non-empty here.
    sex = np.concatenate(sex_list, axis=0)
    return cause_cif, cif_full, survival, y_cause_within, sex
def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    """Return up to ``top_k`` column indices of ``y_ever`` by descending column sum.

    Columns whose sum is not positive are excluded. Ties are broken by
    ascending column index (stable sort), so the selection is deterministic.
    A non-positive ``top_k`` yields an empty array.

    Args:
        y_ever: 2-D array summed over axis 0 to obtain per-column counts.
        top_k: Maximum number of column indices to return.

    Returns:
        1-D integer array of selected column indices.
    """
    counts = np.asarray(y_ever).sum(axis=0)
    # Stable sort => equal counts keep ascending index order.
    order = np.argsort(-counts, kind="stable")
    order = order[counts[order] > 0]
    # Clamp so a negative top_k does not turn into "drop the last |k| items".
    return order[: max(int(top_k), 0)]
def evaluate_one_model(
    model_name: str,
    cause_cif: np.ndarray,
    y_cause_within_tau: np.ndarray,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
) -> None:
    """Append cause-specific, horizon-specific metric rows for one model.

    For each horizon (index ``h`` of ``eval_horizons``) and each cause column
    ``j`` of ``top_cause_ids``, the predictions ``cause_cif[:, j, h]`` are
    compared with the labels ``y_cause_within_tau[:, j, h]`` and the
    following rows are appended to ``out_rows``:

    * ``cause_brier``: ``brier_score`` (primary, no CI);
    * ``cause_ici``: the ``"ici"`` entry of ``calibration_deciles`` (no CI);
    * ``cause_auc``: AUC with CI from ``delong_ci`` (default) or
      ``bootstrap_auc_ci``; when ``auc_ci_method == "none"`` the AUC and both
      bounds are NaN (no AUC is computed at all).

    One ``calib_rows`` entry is appended per bin reported by
    ``calibration_deciles`` (``task == "cause_k"``).

    Returns:
        None. Results are written into ``out_rows`` / ``calib_rows`` in place.
    """

    def _metric_row(metric: str, horizon: float, cause: int, value: Any,
                    low: Any = "", high: Any = "") -> Dict[str, Any]:
        # Row layout shared by every metric written to the results CSV.
        return {
            "model_name": model_name,
            "metric_name": metric,
            "horizon": horizon,
            "cause": cause,
            "value": value,
            "ci_low": low,
            "ci_high": high,
        }

    cause_list = top_cause_ids.tolist()
    for h_idx, tau in enumerate(eval_horizons):
        horizon = float(tau)
        for col, raw_cause in enumerate(cause_list):
            cause = int(raw_cause)
            probs = cause_cif[:, col, h_idx]
            labels = y_cause_within_tau[:, col, h_idx]

            # Primary: CIF-based Brier score, then ICI from calibration deciles.
            brier = brier_score(probs, labels)
            out_rows.append(_metric_row("cause_brier", horizon, cause, brier))

            calib = calibration_deciles(probs, labels, n_bins=n_calib_bins)
            out_rows.append(
                _metric_row("cause_ici", horizon, cause, calib["ici"]))

            # Secondary: discrimination at the same horizon.
            if auc_ci_method == "none":
                nan = float("nan")
                auc, lo, hi = nan, nan, nan
            elif auc_ci_method == "bootstrap":
                auc, lo, hi = bootstrap_auc_ci(
                    probs, labels, n_bootstrap=bootstrap_n, alpha=0.95)
            else:
                auc, lo, hi = delong_ci(labels, probs, alpha=0.95)
            out_rows.append(
                _metric_row("cause_auc", horizon, cause, auc, lo, hi))

            # Calibration-curve points for this cause and horizon.
            calib_rows.extend(
                {
                    "model_name": model_name,
                    "task": "cause_k",
                    "horizon": horizon,
                    "cause_id": cause,
                    "bin_index": int(bin_info["bin"]),
                    "p_mean": float(bin_info["p_mean"]),
                    "y_mean": float(bin_info["y_mean"]),
                    "n_in_bin": int(bin_info["n"]),
                }
                for bin_info in calib.get("bins", [])
            )
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write calibration-curve bin rows to ``path`` as a CSV with a header.

    The header is written even when ``rows`` is empty. Row keys must be a
    subset of the column list below: ``csv.DictWriter`` raises ``ValueError``
    for unknown keys and writes ``""`` for missing ones. The file is encoded
    explicitly as UTF-8 so non-ASCII model names do not depend on the
    platform's default encoding.
    """
    fieldnames = [
        "model_name",
        "task",
        "horizon",
        "cause_id",
        "bin_index",
        "p_mean",
        "y_mean",
        "n_in_bin",
    ]
    # newline="" is required by the csv module so it controls line endings.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write metric rows (one metric value per row) to ``path`` as a CSV.

    The header is written even when ``rows`` is empty. Row keys must be a
    subset of the column list below: ``csv.DictWriter`` raises ``ValueError``
    for unknown keys and writes ``""`` for missing ones. The file is encoded
    explicitly as UTF-8 so non-ASCII model names do not depend on the
    platform's default encoding.
    """
    fieldnames = [
        "model_name",
        "metric_name",
        "horizon",
        "cause",
        "value",
        "ci_low",
        "ci_high",
    ]
    # newline="" is required by the csv module so it controls line endings.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def _make_eval_tag(split: str, offset_years: float) -> str:
"""Short tag for filenames written into run directories."""
off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
return f"{split}_offset{off}y"
def main() -> int:
    """Command-line entry point: evaluate every model listed in ``--models_json``.

    Steps:
      1. Parse CLI arguments and seed RNGs via ``set_deterministic``.
      2. Select focus causes once (model-agnostic) on the evaluation split
         built from the first model's config, counting occurrences within
         ``tau_max`` (the largest evaluation horizon). Export
         ``focus_causes.csv`` and ``horizon_groups.csv``.
      3. For each model: rebuild dataset/split from its own training config,
         load the checkpoint, predict CIFs, run integrity checks, compute
         per-cause/per-horizon metrics and the Experiment 1-3 export rows, and
         write per-model results next to the checkpoint.
      4. Write the aggregated results/calibration CSVs, the experiment export
         CSVs, a markdown manifest and a meta JSON.

    Returns:
        0 on success (used as the process exit code).
    """
    ap = argparse.ArgumentParser(
        description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True,
                    help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str,
                    default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"], help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5,
                    help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float,
                    nargs="*", default=DEFAULT_EVAL_HORIZONS)
    # NOTE: top_k_causes is only recorded in the output metadata; the causes
    # actually evaluated are selected via --death_cause_id and --focus_k.
    ap.add_argument("--top_k_causes", type=int, default=50)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--out_csv", type=str, default="eval_results.csv")
    ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
    # Integrity checks
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)
    # AUC CI methods
    ap.add_argument(
        "--auc_ci_method",
        type=str,
        default="delong",
        choices=["delong", "bootstrap", "none"],
    )
    ap.add_argument("--bootstrap_n", type=int, default=2000)
    # Export settings for user-facing experiments
    ap.add_argument("--export_dir", type=str, default="eval_exports")
    ap.add_argument("--death_cause_id", type=int,
                    default=DEFAULT_DEATH_CAUSE_ID)
    ap.add_argument("--focus_k", type=int, default=5,
                    help="Additional non-death causes to include")
    ap.add_argument("--capture_k_pcts", type=int,
                    nargs="*", default=[1, 5, 10, 20])
    ap.add_argument(
        "--capture_curve_max_pct",
        type=int,
        default=50,
        help="If >0, also export a dense capture curve for k=1..max_pct",
    )
    args = ap.parse_args()
    # nargs="*" accepts an empty list, which would later crash max() with an
    # opaque error; reject it up front with a clear CLI message.
    if not args.eval_horizons:
        ap.error("--eval_horizons requires at least one value")

    set_deterministic(args.seed)

    specs = load_models_json(args.models_json)
    if not specs:
        raise ValueError("No models provided")

    export_dir = str(args.export_dir)
    _ensure_dir(export_dir)
    # NOTE(review): the label file path is hard-coded and resolved relative to
    # the current working directory.
    cause_names = load_cause_names("labels.csv")

    # Determine top-K causes from the evaluation split only (model-agnostic),
    # aligned to time-dependent risk: occurrence within tau_max after context.
    # The split is rebuilt from the FIRST model's config; presumably all listed
    # models share split settings — TODO confirm.
    first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
    cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [
        "bmi", "smoking", "alcohol"]
    dataset_for_top = HealthDataset(
        data_prefix=args.data_prefix, covariate_list=cov_list)
    subset_for_top = build_eval_subset(
        dataset_for_top,
        train_ratio=float(first_cfg.get("train_ratio", 0.7)),
        val_ratio=float(first_cfg.get("val_ratio", 0.15)),
        seed=int(first_cfg.get("random_seed", 42)),
        split=args.split,
    )
    loader_top = DataLoader(
        subset_for_top,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=health_collate_fn,
    )
    tau_max = float(max(args.eval_horizons))
    counts, n_total_eval = count_occurs_within_horizon(
        loader=loader_top,
        offset_years=args.offset_years,
        tau_years=tau_max,
        n_disease=dataset_for_top.n_disease,
        device=args.device,
    )
    focus_causes = pick_focus_causes(
        counts_within_tau=counts,
        n_disease=int(dataset_for_top.n_disease),
        death_cause_id=int(args.death_cause_id),
        k=int(args.focus_k),
    )
    top_cause_ids = np.asarray(focus_causes, dtype=int)

    # Export the chosen focus causes (1-based rank in selection order).
    focus_rows: List[Dict[str, Any]] = []
    for r, cid in enumerate(focus_causes, start=1):
        row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)}
        if cid in cause_names:
            row["cause_name"] = cause_names[cid]
        focus_rows.append(row)
    # Only emit the cause_name column when at least one name was found.
    focus_fieldnames = ["cause", "rank"] + \
        (["cause_name"] if any("cause_name" in r for r in focus_rows) else [])
    write_simple_csv(os.path.join(export_dir, "focus_causes.csv"),
                     focus_fieldnames, focus_rows)

    # Metadata for focus causes (within tau_max).
    top_causes_meta: List[Dict[str, Any]] = []
    for cid in focus_causes:
        # Ids beyond the counts vector are reported with zero cases.
        n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0
        top_causes_meta.append(
            {
                "cause_id": int(cid),
                "tau_years": float(tau_max),
                "n_case_within_tau": n_case,
                "n_control_within_tau": int(n_total_eval - n_case),
                "n_total_eval": int(n_total_eval),
            }
        )

    # Horizon groups for Experiment 3
    hg_rows, horizon_to_group, hg_method = make_horizon_groups(
        args.eval_horizons)
    write_simple_csv(
        os.path.join(export_dir, "horizon_groups.csv"),
        ["horizon", "group", "group_rank"],
        hg_rows,
    )

    rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []
    # Experiment exports (accumulated across models)
    rs_bins_rows: List[Dict[str, Any]] = []
    rs_sum_rows: List[Dict[str, Any]] = []
    cap_points_rows: List[Dict[str, Any]] = []
    cap_curve_rows: List[Dict[str, Any]] = []
    cal_group_sum_rows: List[Dict[str, Any]] = []
    cal_group_bins_rows: List[Dict[str, Any]] = []
    # Track per-model integrity status for meta JSON.
    integrity_meta: Dict[str, Any] = {}

    # The per-model filename tag depends only on CLI arguments; compute it once.
    tag = _make_eval_tag(args.split, float(args.offset_years))

    # Evaluate each model
    for spec in specs:
        run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))
        # Remember list offsets so we can write per-model slices to the model's run_dir.
        rows_start = len(rows)
        calib_start = len(calib_rows)

        cfg = load_train_config_for_checkpoint(spec.checkpoint_path)

        # Identifiers for consistent exports (config values take precedence over the spec).
        model_id = str(spec.name)
        model_type = str(
            cfg.get("model_type", spec.model_type if hasattr(spec, "model_type") else ""))
        loss_type_id = str(
            cfg.get("loss_type", spec.loss_type if hasattr(spec, "loss_type") else ""))
        age_encoder = str(cfg.get("age_encoder", ""))
        cov_type = "full" if _parse_bool(
            cfg.get("full_cov", False)) else "partial"

        # Rebuild the dataset and split using this model's own training config.
        cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [
            "bmi", "smoking", "alcohol"]
        dataset = HealthDataset(
            data_prefix=args.data_prefix, covariate_list=cov_list)
        subset = build_eval_subset(
            dataset,
            train_ratio=float(cfg.get("train_ratio", 0.7)),
            val_ratio=float(cfg.get("val_ratio", 0.15)),
            seed=int(cfg.get("random_seed", 42)),
            split=args.split,
        )
        loader = DataLoader(
            subset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=health_collate_fn,
        )

        backbone, head, loss_type, bin_edges = instantiate_model_and_head(
            cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
        ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
        backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
        head.load_state_dict(ckpt["head_state_dict"], strict=True)

        (
            cause_cif,
            cif_full,
            survival,
            y_cause_within_tau,
            sex,
        ) = predict_cifs_for_model(
            backbone,
            head,
            loss_type,
            bin_edges,
            loader,
            args.device,
            args.offset_years,
            args.eval_horizons,
            top_cause_ids,
        )

        # CIF integrity checks before metrics.
        integrity_ok, integrity_notes = check_cif_integrity(
            cif_full,
            args.eval_horizons,
            tol=float(args.integrity_tol),
            name=spec.name,
            strict=bool(args.integrity_strict),
            survival=survival,
        )
        integrity_meta[spec.name] = {
            "integrity_ok": bool(integrity_ok),
            "integrity_notes": integrity_notes,
        }

        evaluate_one_model(
            model_name=spec.name,
            cause_cif=cause_cif,
            y_cause_within_tau=y_cause_within_tau,
            eval_horizons=args.eval_horizons,
            top_cause_ids=top_cause_ids,
            out_rows=rows,
            calib_rows=calib_rows,
            auc_ci_method=str(args.auc_ci_method),
            bootstrap_n=int(args.bootstrap_n),
        )

        # ============================================================
        # Experiment 1: Risk stratification bins + summary
        # ============================================================
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                for j, cause_id in enumerate(top_cause_ids.tolist()):
                    p = cause_cif[:, j, h_i]
                    y = y_cause_within_tau[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    q_used, bin_rows, summary = compute_risk_stratification_bins(
                        p, y, q_default=10)
                    for br in bin_rows:
                        rs_bins_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "horizon": float(tau),
                                "sex": sex_label,
                                "q": int(br["q"]),
                                "n_bin": int(br["n_bin"]),
                                "p_mean": _safe_float(br["p_mean"]),
                                "y_rate": _safe_float(br["y_rate"]),
                                "y_overall": _safe_float(br["y_overall"]),
                                "lift_vs_overall": _safe_float(br["lift_vs_overall"]),
                                "q_total": int(q_used),
                            }
                        )
                    rs_sum_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "cause": int(cause_id),
                            "horizon": float(tau),
                            "sex": sex_label,
                            "q_total": int(q_used),
                            "top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]),
                            "bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]),
                            "lift_top10_vs_bottom50": _safe_float(summary["lift_top10_vs_bottom50"]),
                            "slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]),
                        }
                    )

        # ============================================================
        # Experiment 2: High-risk capture points (+ optional curve)
        # ============================================================
        k_pcts = [int(x) for x in args.capture_k_pcts]
        curve_max = int(args.capture_curve_max_pct)
        curve_grid = list(range(1, curve_max + 1)
                          ) if curve_max and curve_max > 0 else []
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                for j, cause_id in enumerate(top_cause_ids.tolist()):
                    p = cause_cif[:, j, h_i]
                    y = y_cause_within_tau[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    for r in compute_capture_points(p, y, k_pcts):
                        cap_points_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "horizon": float(tau),
                                "sex": sex_label,
                                **r,
                            }
                        )
                    if curve_grid:
                        for r in compute_capture_points(p, y, curve_grid):
                            cap_curve_rows.append(
                                {
                                    "model_id": model_id,
                                    "model_type": model_type,
                                    "loss_type": loss_type_id,
                                    "age_encoder": age_encoder,
                                    "cov_type": cov_type,
                                    "cause": int(cause_id),
                                    "horizon": float(tau),
                                    "sex": sex_label,
                                    **r,
                                }
                            )

        # ============================================================
        # Experiment 3: Short/Medium/Long horizon-group calibration
        # ============================================================
        # Per-horizon metrics for grouping
        # Build a dict for quick access: (cause_id, horizon) -> (brier, ici)
        # These come from this model's evaluate_one_model rows (full sample),
        # so the grouped Brier/ICI below are NOT recomputed per sex slice.
        per_h: Dict[Tuple[int, float], Dict[str, float]] = {}
        for rr in rows[rows_start:]:
            if rr.get("model_name") != spec.name:
                continue
            if rr.get("metric_name") not in {"cause_brier", "cause_ici"}:
                continue
            try:
                cid = int(rr.get("cause"))
            except Exception:
                continue
            h = _safe_float(rr.get("horizon"))
            if not np.isfinite(h):
                continue
            key = (cid, float(h))
            d = per_h.get(key, {})
            d[str(rr.get("metric_name"))] = _safe_float(rr.get("value"))
            per_h[key] = d

        # Compute group summaries and pooled bins using the same quantile bins as exp1 (per slice).
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for j, cause_id in enumerate(top_cause_ids.tolist()):
                # Decide Q per slice for pooled reliability curve
                n_slice = int(np.sum(sex_mask)) if sex_mask is not None else int(
                    sex.shape[0])
                q_pool = 10 if n_slice >= 200 else 5

                # Collect per-horizon brier/ici values
                group_vals: Dict[str, Dict[str, List[float]]] = {"short": {"brier": [], "ici": [
                ]}, "medium": {"brier": [], "ici": []}, "long": {"brier": [], "ici": []}}
                group_n_total: Dict[str, int] = {
                    "short": 0, "medium": 0, "long": 0}
                # Pooled bins: group -> q -> accumulators
                pooled: Dict[str, Dict[int, Dict[str, float]]] = {
                    "short": {}, "medium": {}, "long": {}}

                for h_i, tau in enumerate(args.eval_horizons):
                    # Horizons missing from the grouping map fall back to "long".
                    g = horizon_to_group.get(float(tau), "long")

                    # brier/ici per horizon (already computed at full-sample level)
                    d = per_h.get((int(cause_id), float(tau)), {})
                    brier_h = _safe_float(d.get("cause_brier"))
                    ici_h = _safe_float(d.get("cause_ici"))
                    if np.isfinite(brier_h):
                        group_vals[g]["brier"].append(brier_h)
                    if np.isfinite(ici_h):
                        group_vals[g]["ici"].append(ici_h)

                    # pooled reliability bins from raw p/y
                    p = cause_cif[:, j, h_i]
                    y = y_cause_within_tau[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    if p.size == 0:
                        continue
                    edges = _quantile_edges(p, q_pool)
                    for qi in range(q_pool):
                        # Half-open bin (edges[qi], edges[qi + 1]].
                        m = (p > edges[qi]) & (p <= edges[qi + 1])
                        nb = int(np.sum(m))
                        if nb == 0:
                            continue
                        pm = float(np.mean(p[m]))
                        yr = float(np.mean(y[m]))
                        # Accumulate size-weighted sums so pooled means are
                        # weighted by bin size across horizons in the group.
                        acc = pooled[g].get(
                            qi + 1, {"n": 0.0, "p_sum": 0.0, "y_sum": 0.0})
                        acc["n"] += float(nb)
                        acc["p_sum"] += float(nb) * pm
                        acc["y_sum"] += float(nb) * yr
                        pooled[g][qi + 1] = acc
                    group_n_total[g] = max(group_n_total[g], int(p.size))

                for g in ["short", "medium", "long"]:
                    bvals = group_vals[g]["brier"]
                    ivals = group_vals[g]["ici"]
                    cal_group_sum_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "cause": int(cause_id),
                            "sex": sex_label,
                            "horizon_group": g,
                            "brier_mean": float(np.mean(bvals)) if bvals else float("nan"),
                            "brier_median": float(np.median(bvals)) if bvals else float("nan"),
                            "ici_mean": float(np.mean(ivals)) if ivals else float("nan"),
                            "ici_median": float(np.median(ivals)) if ivals else float("nan"),
                            "n_total": int(group_n_total[g]),
                            "horizon_grouping_method": hg_method,
                        }
                    )

                    for qi in range(1, q_pool + 1):
                        acc = pooled[g].get(qi)
                        if not acc or float(acc.get("n", 0.0)) <= 0:
                            continue
                        n_bin = float(acc["n"])
                        cal_group_bins_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "sex": sex_label,
                                "horizon_group": g,
                                "q": int(qi),
                                "n_bin": int(n_bin),
                                "p_mean": float(acc["p_sum"] / n_bin),
                                "y_rate": float(acc["y_sum"] / n_bin),
                                "q_total": int(q_pool),
                                "horizon_grouping_method": hg_method,
                            }
                        )

        # Optionally write top-cause counts into the main results CSV as metric rows.
        for tc in top_causes_meta:
            rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_case_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_case_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_control_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_control_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_total_eval",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_total_eval"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

        # Write per-model results into the model's run directory.
        model_rows = rows[rows_start:]
        model_calib_rows = calib_rows[calib_start:]
        model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
        model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")

        write_results_csv(model_out_csv, model_rows)
        write_calibration_bins_csv(model_calib_csv, model_calib_rows)
        model_meta = {
            "model_name": spec.name,
            "checkpoint_path": spec.checkpoint_path,
            "run_dir": run_dir,
            "split": args.split,
            "offset_years": args.offset_years,
            "eval_horizons": [float(x) for x in args.eval_horizons],
            "top_k_causes": int(args.top_k_causes),
            "top_cause_ids": top_cause_ids.tolist(),
            "top_causes": top_causes_meta,
            "integrity": {spec.name: integrity_meta.get(spec.name, {})},
            "paths": {
                "results_csv": model_out_csv,
                "calibration_bins_csv": model_calib_csv,
            },
        }
        with open(model_meta_json, "w") as f:
            json.dump(model_meta, f, indent=2)

        print(f"Wrote per-model results to {model_out_csv}")

    write_results_csv(args.out_csv, rows)

    # Write calibration curve points to a separate CSV.
    out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
    calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
    write_calibration_bins_csv(calib_csv_path, calib_rows)

    # Write experiment exports
    write_simple_csv(
        os.path.join(export_dir, "risk_stratification_bins.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "q",
            "n_bin",
            "p_mean",
            "y_rate",
            "y_overall",
            "lift_vs_overall",
            "q_total",
        ],
        rs_bins_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "risk_stratification_summary.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "q_total",
            "top_decile_y_rate",
            "bottom_half_y_rate",
            "lift_top10_vs_bottom50",
            "slope_pred_vs_obs",
        ],
        rs_sum_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "lift_capture_points.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "k_pct",
            "n_targeted",
            "events_targeted",
            "events_total",
            "event_capture_rate",
            "precision_in_targeted",
        ],
        cap_points_rows,
    )
    if cap_curve_rows:
        write_simple_csv(
            os.path.join(export_dir, "lift_capture_curve.csv"),
            [
                "model_id",
                "model_type",
                "loss_type",
                "age_encoder",
                "cov_type",
                "cause",
                "horizon",
                "sex",
                "k_pct",
                "n_targeted",
                "events_targeted",
                "events_total",
                "event_capture_rate",
                "precision_in_targeted",
            ],
            cap_curve_rows,
        )
    write_simple_csv(
        os.path.join(export_dir, "calibration_groups_summary.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "sex",
            "horizon_group",
            "brier_mean",
            "brier_median",
            "ici_mean",
            "ici_median",
            "n_total",
            "horizon_grouping_method",
        ],
        cal_group_sum_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "calibration_groups_bins.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "sex",
            "horizon_group",
            "q",
            "n_bin",
            "p_mean",
            "y_rate",
            "q_total",
            "horizon_grouping_method",
        ],
        cal_group_bins_rows,
    )

    # Manifest markdown (stable, user-facing)
    manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
    with open(manifest_path, "w", encoding="utf-8") as f:
        f.write(
            "# Evaluation Exports Manifest\n\n"
            "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
            "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
            "## Files\n\n"
            "- focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table.\n"
            "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n"
            "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n"
            "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n"
            "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
            "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
            "- calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket.\n"
            "- calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). Intended plot: 3-panel reliability curves per model.\n"
        )

    meta = {
        "split": args.split,
        "offset_years": args.offset_years,
        "eval_horizons": [float(x) for x in args.eval_horizons],
        "tau_max": float(tau_max),
        "top_k_causes": int(args.top_k_causes),
        "top_cause_ids": top_cause_ids.tolist(),
        "top_causes": top_causes_meta,
        "integrity": integrity_meta,
        "notes": {
            "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
            "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
            "secondary_metrics": "cause_auc (discrimination) with optional CI",
            "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
            "exports_dir": export_dir,
            # Cast to built-in int: json.dump cannot serialize NumPy integers.
            "focus_causes": [int(c) for c in focus_causes],
            "horizon_grouping_method": hg_method,
        },
    }
    with open(args.out_meta_json, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Wrote {args.out_csv} with {len(rows)} rows")
    print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
    print(f"Wrote {args.out_meta_json}")
    return 0
if __name__ == "__main__":
raise SystemExit(main())