Files
DeepHealth/evaluate_models.py

1554 lines
53 KiB
Python
Raw Normal View History

import argparse
import csv
import json
import math
import os
import random
import statistics
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from losses import DiscreteTimeCIFNLLLoss
from model import DelphiFork, SapDelphi, SimpleHead
# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
# Discrete-time bin edges, in years: starts at 0 and ends with +inf so the last
# bin is open-ended. Used as the fallback when a train config has no "bin_edges".
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
# Default evaluation horizons, in years. Each value coincides with a finite
# entry of DEFAULT_BIN_EDGES, which the discrete-time CIF conversion requires.
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
# Factor used to convert year-valued offsets/horizons to the day-valued times
# stored in `time_seq`.
DAYS_PER_YEAR = 365.25
# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one model entry read from the models JSON list."""

    # Display name of the model entry.
    name: str
    model_type: str  # delphi_fork | sap_delphi
    loss_type: str  # exponential | discrete_time_cif
    # Parsed from the JSON "full_cov" field; presumably marks whether the model
    # was trained with the full covariate set - TODO confirm against training code.
    full_cov: bool
    # Path to the checkpoint file; train_config.json is looked up next to it.
    checkpoint_path: str
# ============================================================
# Determinism
# ============================================================
def set_deterministic(seed: int) -> None:
    """Seed all random number generators and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus every
    CUDA device) with the same value, then switches cuDNN to deterministic
    mode with benchmarking (autotuning) turned off.

    Args:
        seed: Seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# ============================================================
# Utilities
# ============================================================
def _parse_bool(x: Any) -> bool:
    """Coerce a JSON-ish value to ``bool``.

    Real booleans pass through unchanged. Anything else is stringified,
    stripped and lower-cased, then matched against the accepted spellings
    ("true"/"1"/"yes"/"y" and "false"/"0"/"no"/"n").

    Raises:
        ValueError: If the value matches none of the accepted spellings.
    """
    if isinstance(x, bool):
        return x

    spellings = {
        "true": True,
        "1": True,
        "yes": True,
        "y": True,
        "false": False,
        "0": False,
        "no": False,
        "n": False,
    }
    parsed = spellings.get(str(x).strip().lower())
    if parsed is None:
        raise ValueError(f"Cannot parse boolean: {x!r}")
    return parsed
def load_models_json(path: str) -> List[ModelSpec]:
    """Read the models JSON file and turn each entry into a ``ModelSpec``.

    The file must contain a JSON list. Every entry needs the keys ``name``,
    ``model_type``, ``loss_type``, ``full_cov`` and ``checkpoint_path``.

    Args:
        path: Location of the JSON file.

    Returns:
        One ``ModelSpec`` per entry, in file order.

    Raises:
        ValueError: If the top-level JSON value is not a list, or ``full_cov``
            cannot be parsed as a boolean.
        KeyError: If an entry lacks one of the required keys.
    """
    with open(path, "r") as handle:
        entries = json.load(handle)

    if not isinstance(entries, list):
        raise ValueError("models_json must be a list of model entries")

    return [
        ModelSpec(
            name=str(entry["name"]),
            model_type=str(entry["model_type"]),
            loss_type=str(entry["loss_type"]),
            full_cov=_parse_bool(entry["full_cov"]),
            checkpoint_path=str(entry["checkpoint_path"]),
        )
        for entry in entries
    ]
def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    """Load ``train_config.json`` from the directory that holds a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file (it does not need to exist;
            only its directory is used).

    Returns:
        The parsed JSON content of ``<checkpoint dir>/train_config.json``.
    """
    checkpoint_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    config_file = os.path.join(checkpoint_dir, "train_config.json")
    with open(config_file, "r") as handle:
        return json.load(handle)
def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    """Return the requested partition of ``dataset`` from a seeded random split.

    Train and validation sizes are the truncated products of the dataset
    length with the given ratios; the test partition takes the remainder.
    The split is always performed (even for ``"all"``) with a generator seeded
    by ``seed``.

    Args:
        dataset: Full dataset to partition.
        train_ratio: Fraction of samples assigned to the train partition.
        val_ratio: Fraction of samples assigned to the validation partition.
        seed: Seed for the split generator.
        split: One of ``"train"``, ``"val"``, ``"test"`` or ``"all"``.

    Returns:
        The selected subset, or ``dataset`` itself for ``"all"``.

    Raises:
        ValueError: If ``split`` is not one of the four accepted names.
    """
    total = len(dataset)
    train_count = int(total * train_ratio)
    val_count = int(total * val_ratio)
    sizes = [train_count, val_count, total - train_count - val_count]

    split_rng = torch.Generator().manual_seed(seed)
    train_part, val_part, test_part = random_split(
        dataset, sizes, generator=split_rng)

    choices = (
        ("train", train_part),
        ("val", val_part),
        ("test", test_part),
        ("all", dataset),
    )
    for label, subset in choices:
        if split == label:
            return subset
    raise ValueError("split must be one of: train, val, test, all")
# ============================================================
# Context selection (anti-leakage)
# ============================================================
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Choose, for every sample, the token used as prediction context.

    Semantics:
    - Tokens with event id 0 are padding; everything else is a valid token.
    - The time of the last valid token is taken as the FOLLOW-UP END time; it
      is not interpreted as an event time.
    - The context is the last valid token whose time is
      <= (follow-up end - offset_years converted to days).

    The context index is derived from the COUNT of eligible tokens (count - 1).
    That equals the position of the last eligible token only when eligible
    tokens form a leading prefix of the row - presumably guaranteed by
    time-sorted, right-padded sequences; verify against the collate function.

    Returns:
        keep_mask: (B,) bool, True where at least one eligible token exists
        t_ctx: (B,) long, context index into the sequence (0 when not kept)
        t_ctx_time: (B,) time at the context index, same units as time_seq (days)
    """
    rows = torch.arange(event_seq.size(0), device=event_seq.device)

    is_token = event_seq != 0
    final_pos = torch.clamp(is_token.sum(dim=1) - 1, min=0)

    # Shift the follow-up end back by the anti-leakage offset.
    cutoff = time_seq[rows, final_pos] - (offset_years * DAYS_PER_YEAR)

    usable = is_token & (time_seq <= cutoff.unsqueeze(1))
    n_usable = usable.sum(dim=1)

    ctx_idx = torch.clamp(n_usable - 1, min=0).to(torch.long)
    return n_usable > 0, ctx_idx, time_seq[rows, ctx_idx]
def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the first disease token located after the context index.

    Disease tokens are those with event id >= 2; the disease id is the token
    minus 2. "After" is decided by sequence position, not by time, so a
    same-day event stored after the context token still counts. This relies on
    the sequences being time-sorted.

    Returns:
        dt_years: (B,) float, years from context to that event; +inf if none
        cause: (B,) long, disease id of that event; -1 if none
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(n_rows, device=dev)
    pos = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)

    # event_seq >= 2 already excludes padding (id 0).
    is_later_disease = (pos > t_ctx.unsqueeze(1)) & (event_seq >= 2)

    # Non-candidates get the sentinel position `seq_len`, so the row minimum is
    # the earliest candidate or `seq_len` when there is none.
    sentinel = torch.full_like(pos, seq_len)
    first_pos = torch.where(is_later_disease, pos, sentinel).min(dim=1).values
    found = first_pos < seq_len

    # Index 0 is a harmless placeholder for rows without a future event; their
    # outputs are overwritten below.
    safe_pos = torch.where(found, first_pos, torch.zeros_like(first_pos))

    elapsed_years = (time_seq[rows, safe_pos] - time_seq[rows, t_ctx]) / DAYS_PER_YEAR
    dt_years = torch.where(
        found, elapsed_years, torch.full_like(elapsed_years, float("inf")))

    disease = (event_seq[rows, safe_pos] - 2).to(torch.long)
    cause = torch.where(found, disease, torch.full_like(disease, -1))
    return dt_years, cause
def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Multi-hot labels: does disease k occur within ``tau_years`` after context?

    A token counts when it sits after the context index, is a disease token
    (event id >= 2) and its time lies in [t_ctx_time, t_ctx_time + tau], with
    tau converted from years to days. Any occurrence sets the label.

    Returns:
        (B, n_disease) bool tensor on the device of ``event_seq``.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(n_rows, device=dev)

    window_start = time_seq[rows, t_ctx]
    window_end = window_start + (tau_years * DAYS_PER_YEAR)

    pos = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    after_context = pos > t_ctx.unsqueeze(1)
    inside_window = (time_seq >= window_start.unsqueeze(1)) & (
        time_seq <= window_end.unsqueeze(1))
    # event_seq >= 2 already excludes padding (id 0).
    hits = after_context & inside_window & (event_seq >= 2)

    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=dev)
    if not hits.any():
        return labels

    hit_rows, hit_cols = hits.nonzero(as_tuple=True)
    hit_disease = (event_seq[hit_rows, hit_cols] - 2).to(torch.long)
    labels[hit_rows, hit_disease] = True
    return labels
def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Multi-hot labels: does disease k occur at ANY position after the context?

    This is the Delphi2M-compatible case/control definition for Task A. No
    time window is applied; a token counts when its sequence position is
    greater than the context index and it is a disease token (event id >= 2).
    Same-day events therefore count as long as they are stored after the
    context token.

    Returns:
        (B, n_disease) bool tensor on the device of ``event_seq``.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device

    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=dev)
    pos = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    # event_seq >= 2 already excludes padding (id 0).
    later_disease = (pos > t_ctx.unsqueeze(1)) & (event_seq >= 2)

    if later_disease.any():
        hit_rows, hit_cols = later_disease.nonzero(as_tuple=True)
        hit_disease = (event_seq[hit_rows, hit_cols] - 2).to(torch.long)
        labels[hit_rows, hit_disease] = True
    return labels
def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Within-horizon labels restricted to a chosen set of causes.

    A token counts when it sits after the context index, is a disease token
    (event id >= 2) and its time lies in [t_ctx_time, t_ctx_time + tau], with
    tau converted from years to days. Only diseases listed in ``cause_ids``
    are reported.

    Args:
        cause_ids: 1-D long tensor of global disease ids; column j of the
            output corresponds to ``cause_ids[j]``.
        n_disease: Size of the global disease id space.

    Returns:
        (B, len(cause_ids)) bool tensor on the device of ``event_seq``.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(n_rows, device=dev)

    window_start = time_seq[rows, t_ctx]
    window_end = window_start + (tau_years * DAYS_PER_YEAR)

    pos = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    after_context = pos > t_ctx.unsqueeze(1)
    inside_window = (time_seq >= window_start.unsqueeze(1)) & (
        time_seq <= window_end.unsqueeze(1))
    # event_seq >= 2 already excludes padding (id 0).
    hits = after_context & inside_window & (event_seq >= 2)

    labels = torch.zeros((n_rows, cause_ids.numel()),
                         dtype=torch.bool, device=dev)
    if not hits.any():
        return labels

    hit_rows, hit_cols = hits.nonzero(as_tuple=True)
    hit_disease = (event_seq[hit_rows, hit_cols] - 2).to(torch.long)

    # Global disease id -> column in `labels`; -1 marks diseases that were not
    # selected, so the same table serves as membership test and as index map.
    column_of = torch.full((int(n_disease),), -1, dtype=torch.long, device=dev)
    column_of[cause_ids] = torch.arange(cause_ids.numel(), device=dev)

    hit_columns = column_of[hit_disease]
    is_selected = hit_columns >= 0
    if not is_selected.any():
        return labels

    labels[hit_rows[is_selected], hit_columns[is_selected]] = True
    return labels
# ============================================================
# CIF conversion
# ============================================================
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Convert exponential cause-specific logits into CIFs at the given horizons.

    Hazards are ``softplus(logits) + eps``. With total hazard Lambda = sum_k
    hazard_k, the CIF of cause k at horizon tau is
    ``hazard_k / Lambda * (1 - exp(-Lambda * tau))`` and the survival is
    ``exp(-Lambda * tau)``.

    Args:
        logits: (B, K) raw head outputs.
        taus: H horizons, in the time unit the hazards are expressed in.
        eps: Positive floor added to every hazard and used to clamp the
            denominator.
        return_survival: If True, also return the survival probabilities.

    Returns:
        ``cif`` of shape (B, K, H), or ``(cif, survival)`` with ``survival`` of
        shape (B, H) when ``return_survival`` is True.
    """
    hazards = F.softplus(logits) + eps  # (B,K)
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)
    taus_t = torch.tensor(
        list(taus), device=logits.device, dtype=hazards.dtype).view(1, -1)  # (1,H)

    # Kept as (B,1) so it broadcasts against (B,H) and, after unsqueeze, (B,K,H).
    # A squeezed (B,) mask would be aligned with the H axis instead.
    has_hazard = total > 0

    survival = torch.exp(-total * taus_t)  # (B,H)
    frac = hazards / torch.clamp(total, min=eps)  # (B,K)
    cif = frac.unsqueeze(-1) * (1.0 - survival).unsqueeze(1)  # (B,K,H)

    # If total==0, no event can happen: CIF is 0 and survival is 1.
    cif = torch.where(has_hazard.unsqueeze(-1), cif, torch.zeros_like(cif))
    if not return_survival:
        return cif

    survival = torch.where(has_hazard, survival, torch.ones_like(survival))
    return cif, survival
def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Convert discrete-time CIF logits into CIFs at the given horizons.

    A softmax over the cause axis yields, per time bin, the probability of each
    of the K causes plus a final "complement" (no event) channel. Only the bins
    that end at a finite edge are used: slot 0 and the slot of the open-ended
    last bin are ignored.

    Args:
        logits: (B, K+1, len(bin_edges)).
        bin_edges: Bin edges starting at 0 and ending with +inf.
        taus: Horizons; each must match a finite bin edge within 1e-6.
        return_survival: If True, also return the survival probabilities.

    Returns:
        ``cif`` of shape (B, K, H), or ``(cif, survival)`` with ``survival`` of
        shape (B, H) when ``return_survival`` is True.

    Raises:
        ValueError: If ``logits`` is not 3-D, its last dimension differs from
            ``len(bin_edges)``, or a horizon is non-finite or off the bin edges.
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")
    B, K_plus_1, n_slots = logits.shape
    edges = [float(x) for x in bin_edges]
    if n_slots != len(edges):
        raise ValueError("logits last dim must match len(bin_edges)")
    K = K_plus_1 - 1

    # drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)

    probs = torch.softmax(logits, dim=1)  # (B, K+1, n_slots)
    # use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot)
    hazards = probs[:, :K, 1: 1 + n_bins]  # (B,K,n_bins)
    p_comp = probs[:, K, 1: 1 + n_bins]  # (B,n_bins)

    # survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v<u} p_comp[v]
    ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
    cum = torch.cumprod(p_comp, dim=1)  # cum[u] = prod_{v<=u} p_comp[v]
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # (B,n_bins)
    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * hazards, dim=2)  # (B,K,n_bins)

    # Map each tau to the nearest finite edge. The tolerance absorbs tiny
    # differences introduced by parsing/serialising the horizons.
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        j = int(np.argmin(diffs))
        if diffs[j] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
            )
        tau_to_idx.append(j)
    idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)

    cif = cif_bins.index_select(dim=2, index=idx)  # (B,K,H)
    if not return_survival:
        return cif

    # Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
    survival = cum.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival
# ============================================================
# CIF integrity checks
# ============================================================
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Sanity-check a CIF array and report every violated property.

    Checks performed: array rank, horizon count, finiteness, value range
    [0, 1], monotonicity along horizons, total probability mass <= 1 and, when
    ``survival`` is given, conservation ``1 - survival == sum_k CIF``.

    Args:
        cause_cif: (N, K, H) array of cumulative incidences.
        horizons: Sequence of length H.
        tol: Tolerance for the inequality checks. The conservation check uses
            ``max(tol, 1e-4)``.
        name: Model name used to prefix printed/raised messages.
        strict: If True, the first failed check raises ``ValueError`` instead
            of being printed as a warning.
        survival: Optional (N, H) survival values at the same horizons.

    Returns:
        ``(integrity_ok, notes)``. ``notes`` holds one message per failed
        check plus, if ``survival`` is missing, a note that the conservation
        check was skipped; that note alone does not make the result not-ok.
    """
    skip_suffix = "skipping conservation check"
    notes: List[str] = []
    prefix = f"[{name}] " if name else ""

    def _report(message: str) -> None:
        # Raise in strict mode, otherwise warn and remember the message.
        tagged = prefix + message
        if strict:
            raise ValueError(tagged)
        print("WARNING:", tagged)
        notes.append(message)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _report(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes

    n_samples, _n_causes, n_horizons = cif.shape
    if n_horizons != len(horizons):
        _report(
            f"integrity: horizon length mismatch (H={n_horizons}, len(horizons)={len(horizons)})")

    # Finite values only.
    if not np.isfinite(cif).all():
        _report("integrity: non-finite values (NaN/Inf) in cause_cif")

    # Values must stay inside [0, 1] up to the tolerance.
    lowest = float(np.nanmin(cif))
    highest = float(np.nanmax(cif))
    if lowest < -tol:
        _report(f"integrity: range min={lowest} < -tol={-tol}")
    if highest > 1.0 + tol:
        _report(f"integrity: range max={highest} > 1+tol={1.0+tol}")

    # A CIF can never decrease as the horizon grows.
    steps = np.diff(cif, axis=2)
    if steps.size > 0 and np.nanmin(steps) < -tol:
        _report("integrity: monotonicity violated (found negative diff along horizons)")

    # Summed over causes the CIF is a probability, hence at most 1.
    mass = np.sum(cif, axis=1)  # (N,H)
    peak_mass = float(np.nanmax(mass))
    if peak_mass > 1.0 + tol:
        _report(f"integrity: probability mass exceeds 1 (max sum_k={peak_mass})")

    # Conservation against survival, if provided.
    if survival is None:
        skipped = "integrity: survival not provided; " + skip_suffix
        # Strict mode records the note silently; it never raises for a skip.
        if not strict:
            print("WARNING:", prefix + skipped)
        notes.append(skipped)
    else:
        surv = np.asarray(survival, dtype=float)
        if surv.shape != (n_samples, n_horizons):
            _report(
                f"integrity: survival shape mismatch (got {surv.shape}, expected {(n_samples, n_horizons)})")
        else:
            # Floor of 1e-4 leaves room for accumulated numerical error.
            tol_cons = max(float(tol), 1e-4)
            worst = float(np.nanmax(np.abs((1.0 - surv) - mass)))
            if worst > tol_cons:
                _report(
                    f"integrity: conservation violated (max |(1-surv)-sum_cif|={worst}, tol={tol_cons})")

    ok = all(note.endswith(skip_suffix) for note in notes)
    return ok, notes
# ============================================================
# Metrics
# ============================================================
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Return 1-based midranks of a 1-D array (tied values share their mean rank).

    NaN values sort last and never compare equal, so each NaN receives its own
    rank instead of stalling the tie scan.

    Args:
        x: 1-D array-like of scores.

    Returns:
        Float array of the same length holding the midrank of every element.
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    if n == 0:
        return np.empty(0, dtype=float)

    order = np.argsort(x)
    z = x[order]

    # A tie group starts at position 0 and wherever the sorted value changes.
    is_start = np.empty(n, dtype=bool)
    is_start[0] = True
    np.not_equal(z[1:], z[:-1], out=is_start[1:])
    starts = np.flatnonzero(is_start)
    ends = np.append(starts[1:], n)  # exclusive end of each group

    # Mean of the 1-based ranks start+1 .. end, repeated for every group member.
    group_rank = 0.5 * (starts + ends - 1) + 1.0
    ranks_sorted = np.repeat(group_rank, ends - starts)

    out = np.empty(n, dtype=float)
    out[order] = ranks_sorted
    return out
def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong estimate of AUCs and their covariance matrix.

    Args:
        predictions_sorted_transposed: (n_classifiers, n_examples) scores with
            the positive examples occupying the first ``label_1_count`` columns.
        label_1_count: Number of positive examples.

    Returns:
        ``(aucs, delong_cov)``. When there are no positives or no negatives,
        both are NaN-filled arrays of shape (1,) and (1, 1).
    """
    scores = np.asarray(predictions_sorted_transposed, dtype=float)
    m = int(label_1_count)
    n = int(scores.shape[1] - m)
    if m <= 0 or n <= 0:
        nan = float("nan")
        return np.array([nan]), np.array([[nan]])

    def _midranks_per_row(block: np.ndarray) -> np.ndarray:
        # One midrank vector per classifier (row).
        return np.array([compute_midrank(row) for row in block])

    ranks_pos = _midranks_per_row(scores[:, :m])
    ranks_neg = _midranks_per_row(scores[:, m:])
    ranks_all = _midranks_per_row(scores)

    aucs = (ranks_all[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)

    # Structural components of the DeLong variance.
    v01 = (ranks_all[:, :m] - ranks_pos) / n
    v10 = 1.0 - (ranks_all[:, m:] - ranks_neg) / m

    if v01.shape[0] > 1:
        sx = np.cov(v01)
        sy = np.cov(v10)
    else:
        # Single classifier: np.cov would return a 0-d value, so build the
        # 1x1 matrices from the row-wise sample variance instead.
        sx = np.array([[float(np.var(v01, axis=1, ddof=1)[0])]])
        sy = np.array([[float(np.var(v10, axis=1, ddof=1)[0])]])

    return aucs, sx / m + sy / n
def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    """Compute the ROC AUC and its DeLong variance for one classifier.

    Args:
        ground_truth: 1-D labels; 1 marks positives and 0 marks negatives.
        predictions: 1-D scores of the same length.

    Returns:
        ``(auc, variance)``; both NaN when either class is absent.

    Raises:
        ValueError: If the inputs are not 1-D arrays of equal length.
    """
    labels = np.asarray(ground_truth, dtype=int)
    scores = np.asarray(predictions, dtype=float)

    well_formed = (
        labels.ndim == 1
        and scores.ndim == 1
        and labels.shape[0] == scores.shape[0]
    )
    if not well_formed:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")

    n_pos = int(np.sum(labels == 1))
    n_neg = int(np.sum(labels == 0))
    if n_pos == 0 or n_neg == 0:
        return float("nan"), float("nan")

    # fastDeLong expects the positive examples in the leading columns.
    positives_first = np.argsort(-labels)
    aucs, cov = fastDeLong(scores[positives_first][np.newaxis, :], n_pos)
    return float(aucs[0]), float(cov[0, 0])
def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return ``(auc, ci_low, ci_high)`` from the DeLong variance and a normal CI.

    ``alpha`` is the confidence level (0.95 gives a 95% interval). The bounds
    are clipped to [0, 1]. If the variance is NaN or not positive, a warning
    is printed and both bounds are NaN.
    """
    auc, var = calc_auc_variance(ground_truth, predictions)

    variance_usable = np.isfinite(var) and var > 0
    if not variance_usable:
        print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
        return float(auc), float("nan"), float("nan")

    # Two-sided normal quantile for the requested confidence level.
    upper_tail = 1.0 - (1.0 - float(alpha)) / 2.0
    half_width = statistics.NormalDist().inv_cdf(upper_tail) * math.sqrt(var)

    ci_low = max(0.0, auc - half_width)
    ci_high = min(1.0, auc + half_width)
    return float(auc), float(ci_low), float(ci_high)
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC (Mann-Whitney U statistic, ties resolved by midranks).

    Args:
        y_true: 1-D labels; 1 marks positives and 0 marks negatives.
        y_score: 1-D scores of the same length.

    Returns:
        The AUC, or NaN when either class is absent.

    Raises:
        ValueError: If the inputs are not 1-D arrays of equal length.
    """
    labels = np.asarray(y_true, dtype=int)
    scores = np.asarray(y_score, dtype=float)

    well_formed = (
        labels.ndim == 1
        and scores.ndim == 1
        and labels.shape[0] == scores.shape[0]
    )
    if not well_formed:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")

    is_pos = labels == 1
    n_pos = int(np.sum(is_pos))
    n_neg = int(np.sum(labels == 0))
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # U statistic: rank sum of the positives minus its minimum possible value.
    pos_rank_sum = float(np.sum(compute_midrank(scores)[is_pos]))
    u_stat = pos_rank_sum - n_pos * (n_pos + 1) / 2.0
    return float(u_stat / (n_pos * n_neg))
def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Percentile-bootstrap confidence interval for the ROC AUC.

    Args:
        scores: 1-D predicted scores.
        labels: 1-D binary labels.
        n_bootstrap: Number of resamples drawn with replacement.
        alpha: Confidence level (0.95 gives a 95% interval).
        seed: Seed of the resampling generator.

    Returns:
        ``(auc, ci_low, ci_high)`` where ``auc`` is computed on the full data.
        All three are NaN for degenerate labels. The bounds are NaN when fewer
        than 10 resamples contain both classes.
    """
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    nan = float("nan")

    def _degenerate() -> Tuple[float, float, float]:
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return nan, nan, nan

    if n == 0 or np.all(labels == labels[0]):
        return _degenerate()

    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        return _degenerate()

    resampled_aucs: List[float] = []
    for _ in range(int(n_bootstrap)):
        picks = rng.integers(0, n, size=n)
        boot_labels = labels[picks]
        # A resample holding a single class has no defined AUC.
        if np.all(boot_labels == boot_labels[0]):
            continue
        boot_auc = roc_auc_rank(boot_labels, scores[picks])
        if np.isfinite(boot_auc):
            resampled_aucs.append(float(boot_auc))

    if len(resampled_aucs) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), nan, nan

    lower_q = (1.0 - float(alpha)) / 2.0
    upper_q = 1.0 - lower_q
    ci_low = float(np.quantile(resampled_aucs, lower_q))
    ci_high = float(np.quantile(resampled_aucs, upper_q))
    return float(auc_full), ci_low, ci_high
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    """Return the Brier score: mean squared difference between ``p`` and ``y``."""
    residual = np.asarray(p, dtype=float) - np.asarray(y, dtype=float)
    return float(np.mean(residual ** 2))
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    """Summarise calibration over quantile bins of the predictions.

    Bin edges are the quantiles of ``p`` at ``n_bins + 1`` evenly spaced
    levels, with the outermost edges widened to -inf/+inf. Each bin holds
    predictions in (lower, upper]. Empty bins (possible with tied predictions)
    are skipped.

    Returns:
        Dict with
        - ``"bins"``: list of ``{"bin", "p_mean", "y_mean", "n"}`` per
          non-empty bin,
        - ``"ece"``: sum over bins of (bin fraction * |p_mean - y_mean|),
        - ``"ici"``: unweighted mean of |p_mean - y_mean| over non-empty bins.
        For empty input, ``"bins"`` is empty and both scores are NaN.
    """
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    if p.size == 0:
        return {"bins": [], "ece": float("nan"), "ici": float("nan")}

    edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
    # Open the outer edges so every prediction falls into some bin.
    edges[0] = -np.inf
    edges[-1] = np.inf

    bins = []
    ece = 0.0
    gap_total = 0.0
    for bin_index, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        members = (p > lower) & (p <= upper)
        if not np.any(members):
            continue
        p_mean = float(np.mean(p[members]))
        y_mean = float(np.mean(y[members]))
        gap = abs(p_mean - y_mean)
        bins.append({"bin": bin_index, "p_mean": p_mean,
                     "y_mean": y_mean, "n": int(members.sum())})
        ece += float(np.mean(members)) * gap
        gap_total += gap

    ici = gap_total / max(len(bins), 1)
    return {"bins": bins, "ece": float(ece), "ici": float(ici)}
def count_ever_after_context_anytime(
    loader: DataLoader,
    offset_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count, per disease, how many individuals have it after their context.

    Only samples with a valid prediction context (see
    ``select_context_indices``) are considered. Each individual contributes at
    most once per disease, however often the disease token repeats.

    Returns:
        ``(counts, n_total_eval)`` where ``counts[k]`` is the number of
        individuals with disease k at least once after context and
        ``n_total_eval`` is the number of samples with a valid context.
    """
    k = int(n_disease)
    totals = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_people = 0

    for batch in loader:
        # Only events and times are needed; the feature tensors stay unused.
        event_seq, time_seq, _cont_feats, _cate_feats, _sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)

        has_ctx, ctx_idx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not has_ctx.any():
            continue
        n_people += int(has_ctx.sum().item())

        events = event_seq[has_ctx]
        ctx_idx = ctx_idx[has_ctx]
        n_rows, seq_len = events.shape
        pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(n_rows, -1)
        # events >= 2 already excludes padding (id 0).
        later_disease = (pos > ctx_idx.unsqueeze(1)) & (events >= 2)
        if not later_disease.any():
            continue

        person, slot = later_disease.nonzero(as_tuple=True)
        disease = (events[person, slot] - 2).to(torch.long)

        # Encode (person, disease) as one integer so torch.unique removes
        # repeated occurrences; the modulo recovers the disease id.
        pair_codes = torch.unique(person.to(torch.long) * k + disease)
        once_per_person = pair_codes % k
        totals.scatter_add_(0, once_per_person, torch.ones_like(
            once_per_person, dtype=torch.long))

    return totals.detach().cpu().numpy(), int(n_people)
# ============================================================
# Evaluation core
# ============================================================
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    """Build the backbone and prediction head described by a train config.

    No weights are loaded here; the modules are only constructed and moved to
    ``device``.

    Args:
        cfg: Train config. Reads ``model_type``, ``loss_type``, ``n_embd``,
            ``n_head``, ``n_layer`` and optionally ``pdrop``, ``age_encoder``,
            ``bin_edges``, ``pretrained_emb_path`` / ``pretrained_emd_path``.
        dataset: Supplies ``n_disease``, ``n_cont``, ``n_cate``, ``cate_dims``.
        device: Target device for both modules.
        checkpoint_path: Only used in the warning printed when a SapDelphi
            config lacks an embedding path.

    Returns:
        ``(backbone, head, loss_type, bin_edges)``; ``bin_edges`` falls back to
        ``DEFAULT_BIN_EDGES`` when the config has none.

    Raises:
        ValueError: For an unsupported ``loss_type`` or ``model_type``.
    """
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)

    # Head output layout depends on the loss the model was trained with.
    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")

    def _shared_backbone_kwargs() -> Dict[str, Any]:
        # Constructor arguments common to both backbone classes.
        return dict(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
        )

    if model_type == "delphi_fork":
        backbone = DelphiFork(**_shared_backbone_kwargs()).to(device)
    elif model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, fall back to
        # the misspelled pretrained_emd_path.
        unset = {"", None}
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in unset:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in unset:
            run_dir = os.path.dirname(os.path.abspath(
                checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            **_shared_backbone_kwargs(),
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    return backbone, head, loss_type, bin_edges
@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run a model over ``loader`` and collect CIF predictions plus labels.

    For every batch, samples without a valid prediction context (see
    ``select_context_indices``) are dropped. The backbone output at the
    context index is fed to the head, and the resulting logits are converted
    to CIFs according to ``loss_type``.

    Args:
        backbone: Sequence encoder; called as
            ``backbone(event_seq, time_seq, sexes, cont_feats, cate_feats)``.
        head: Maps the context embedding to logits.
        loss_type: ``"exponential"`` or ``"discrete_time_cif"``.
        bin_edges: Bin edges; only used for ``"discrete_time_cif"``.
        loader: Yields ``(event_seq, time_seq, cont_feats, cate_feats, sexes)``.
        device: Device the batches are moved to.
        offset_years: Anti-leakage offset used when selecting the context.
        eval_horizons: Horizons (years) at which CIFs and labels are computed.
        top_cause_ids: Disease ids for which cause-specific outputs are kept.

    Returns (nine arrays, in this order):
        allcause_risk: (N,H) sum of the CIFs over all causes
        cause_cif: (N, topK, H)
        cif_full: (N, K, H)
        survival: (N, H)
        sex: (N,)
        y_allcause_tau: (N,H), 1 if the next disease event after context
            happens within the horizon
        y_cause_ever_anytime: (N, topK)
        y_cause_within_tau: (N, topK, H)
        y_cause_within_tau_max: (N, topK), within-horizon labels at
            max(eval_horizons)

    Raises:
        ValueError: If ``loss_type`` is not supported.
        RuntimeError: If no sample in the loader has a valid context.

    NOTE:
    - y_cause_ever_anytime is Delphi2M-compatible case/control label.
    - y_cause_within_tau_* corresponds to within-horizon labels (kept for legacy/secondary AUC).
    """
    backbone.eval()
    head.eval()
    # We will accumulate in CPU lists, then concat.
    allcause_list: List[np.ndarray] = []
    cause_cif_list: List[np.ndarray] = []
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []
    y_all_list: List[np.ndarray] = []
    y_cause_ever_any_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    y_cause_within_tau_max_list: List[np.ndarray] = []
    tau_max = float(max(eval_horizons))
    top_cause_ids_t = torch.tensor(
        top_cause_ids, dtype=torch.long, device=device)
    # Efficiency: pre-create horizons tensor once per model (on device) and vectorize comparisons.
    eval_horizons_t = torch.tensor(
        list(eval_horizons), device=device, dtype=torch.float32).view(1, -1)
    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)
        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        # filter batch
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep]
        t_ctx = t_ctx[keep]
        h = backbone(event_seq, time_seq, sexes_k,
                     cont_feats, cate_feats)  # (B,L,D)
        b = torch.arange(h.size(0), device=device)
        # Hidden state at the context token of every sample.
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)
        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                # (B,K,H), (B,H)
                logits, bin_edges, eval_horizons, return_survival=True)
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")
        allcause = cif_full.sum(dim=1)  # (B,H)
        cause_cif = cif_full.index_select(
            dim=1, index=top_cause_ids_t)  # (B,topK,H)
        # outcomes
        dt_next, _cause_next = next_event_after_context(
            event_seq, time_seq, t_ctx)
        # dt_next is +inf when no disease follows, so those rows stay 0.
        y_all = (dt_next.view(-1, 1) <= eval_horizons_t).to(torch.float32)
        # Delphi2M-compatible ever label (does not depend on horizon)
        y_ever_any = multi_hot_ever_after_context_anytime(
            event_seq=event_seq,
            t_ctx=t_ctx,
            n_disease=int(cif_full.size(1)),
        )
        y_ever_any_top = y_ever_any.index_select(
            dim=1, index=top_cause_ids_t).to(torch.float32)
        # Within-horizon labels for cause-specific CIF quality + legacy AUC
        n_disease = int(cif_full.size(1))
        y_within_top = torch.stack(
            [
                multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau),
                    cause_ids=top_cause_ids_t,
                    n_disease=n_disease,
                ).to(torch.float32)
                for tau in eval_horizons
            ],
            dim=2,
        )  # (B,topK,H)
        y_within_tau_max_top = multi_hot_selected_causes_within_horizon(
            event_seq=event_seq,
            time_seq=time_seq,
            t_ctx=t_ctx,
            tau_years=tau_max,
            cause_ids=top_cause_ids_t,
            n_disease=n_disease,
        ).to(torch.float32)
        # Move this batch's results to CPU/NumPy so device memory is not held
        # across batches.
        allcause_list.append(allcause.detach().cpu().numpy())
        cause_cif_list.append(cause_cif.detach().cpu().numpy())
        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())
        y_all_list.append(y_all.detach().cpu().numpy())
        y_cause_ever_any_list.append(y_ever_any_top.detach().cpu().numpy())
        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
        y_cause_within_tau_max_list.append(
            y_within_tau_max_top.detach().cpu().numpy())
    if not allcause_list:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")
    allcause_risk = np.concatenate(allcause_list, axis=0)
    cause_cif = np.concatenate(cause_cif_list, axis=0)
    cif_full = np.concatenate(cif_full_list, axis=0)
    survival = np.concatenate(survival_list, axis=0)
    sex = np.concatenate(sex_list, axis=0)
    y_allcause = np.concatenate(y_all_list, axis=0)
    y_cause_ever_any = np.concatenate(y_cause_ever_any_list, axis=0)
    y_cause_within = np.concatenate(y_cause_within_list, axis=0)
    y_cause_within_tau_max = np.concatenate(y_cause_within_tau_max_list, axis=0)
    return allcause_risk, cause_cif, cif_full, survival, sex, y_allcause, y_cause_ever_any, y_cause_within, y_cause_within_tau_max
def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    """Return the ids of the ``top_k`` most frequent causes, most frequent first.

    Args:
        y_ever: (N, K) indicator array; column sums give per-cause frequencies.
        top_k: Maximum number of causes to return.

    Returns:
        Cause ids (column indices) sorted by descending frequency. Causes that
        never occur are excluded, so fewer than ``top_k`` ids may be returned.
    """
    per_cause = y_ever.sum(axis=0)
    by_frequency = np.argsort(-per_cause)
    observed = by_frequency[per_cause[by_frequency] > 0]
    return observed[:top_k]
def evaluate_one_model(
    model_name: str,
    allcause_risk: np.ndarray,
    cause_cif: np.ndarray,
    sex: np.ndarray,
    y_allcause: np.ndarray,
    y_cause_ever_anytime: np.ndarray,
    y_cause_within_tau: np.ndarray,
    y_cause_within_tau_max: np.ndarray,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
    cause_calib_horizons: Sequence[float] = (3.84, 10.0),
) -> None:
    """Compute all metrics for one model and append them to the output lists.

    Rows are appended in this order:
    1. Per horizon (all-cause): Brier, AUC, ECE, ICI, calibration bins, then
       sex-stratified AUC for sex values 0 and 1 (strata with fewer than 10
       samples are skipped).
    2. Per top cause: ``cause_auc_ever`` (labels = disease occurs anytime
       after context), scored with the CIF at the largest horizon.
    3. Per top cause: ``cause_auc`` (labels = disease occurs within the
       largest horizon), scored with the same CIF.
    4. Per horizon in ``cause_calib_horizons`` that is also an evaluation
       horizon, per top cause: Brier, ECE, ICI and calibration bins.

    Args:
        model_name: Written into every row.
        allcause_risk: (N, H) all-cause risk.
        cause_cif: (N, topK, H) cause-specific CIFs.
        sex: (N,) sex codes used for stratification.
        y_allcause: (N, H) all-cause labels.
        y_cause_ever_anytime: (N, topK) ever-after-context labels.
        y_cause_within_tau: (N, topK, H) within-horizon labels.
        y_cause_within_tau_max: (N, topK) labels at the largest horizon.
        eval_horizons: The H horizons, matching the last axis of the arrays.
        top_cause_ids: Disease ids of the topK causes.
        out_rows: Metric rows are appended here (mutated in place).
        calib_rows: Calibration-bin rows are appended here (mutated in place).
        auc_ci_method: ``"none"`` (AUC and CI are NaN), ``"bootstrap"``, or
            anything else for DeLong.
        bootstrap_n: Number of bootstrap resamples.
        n_calib_bins: Number of calibration bins, used for both all-cause and
            cause-specific calibration.
        cause_calib_horizons: Horizons for cause-specific Brier/calibration;
            matched against ``eval_horizons`` by exact float equality.
    """

    def _metric_row(metric_name: str, horizon: float, cause: Any, value: Any,
                    ci_low: Any = "", ci_high: Any = "") -> Dict[str, Any]:
        # One line of the results CSV.
        return {
            "model_name": model_name,
            "metric_name": metric_name,
            "horizon": horizon,
            "cause": cause,
            "value": value,
            "ci_low": ci_low,
            "ci_high": ci_high,
        }

    def _calib_row(task: str, horizon: float, cause_id: int,
                   binfo: Dict[str, Any]) -> Dict[str, Any]:
        # One line of the calibration-bins CSV.
        return {
            "model_name": model_name,
            "task": task,
            "horizon": horizon,
            "cause_id": cause_id,
            "bin_index": int(binfo["bin"]),
            "p_mean": float(binfo["p_mean"]),
            "y_mean": float(binfo["y_mean"]),
            "n_in_bin": int(binfo["n"]),
        }

    def _auc_with_ci(scores: np.ndarray, labels: np.ndarray) -> Tuple[float, float, float]:
        # Dispatch on the configured CI method; "none" skips the AUC entirely.
        if auc_ci_method == "none":
            nan = float("nan")
            return nan, nan, nan
        if auc_ci_method == "bootstrap":
            return bootstrap_auc_ci(
                scores, labels, n_bootstrap=bootstrap_n, alpha=0.95)
        return delong_ci(labels, scores, alpha=0.95)

    # Task B (all-cause): Brier + AUC + calibration per horizon
    for h_i, tau in enumerate(eval_horizons):
        horizon = float(tau)
        p = allcause_risk[:, h_i]
        y = y_allcause[:, h_i]

        out_rows.append(_metric_row(
            "allcause_brier", horizon, "", brier_score(p, y)))

        auc, lo, hi = _auc_with_ci(p, y)
        out_rows.append(_metric_row("allcause_auc", horizon, "", auc, lo, hi))

        cal = calibration_deciles(p, y, n_bins=n_calib_bins)
        out_rows.append(_metric_row("allcause_ece", horizon, "", cal["ece"]))
        out_rows.append(_metric_row("allcause_ici", horizon, "", cal["ici"]))
        # Calibration bins go into a separate CSV (always for all-cause).
        for binfo in cal.get("bins", []):
            calib_rows.append(_calib_row("all_cause", horizon, -1, binfo))

        # Stratification by sex
        for s_val in [0, 1]:
            m = sex == s_val
            if np.sum(m) < 10:
                continue
            auc_s, lo_s, hi_s = _auc_with_ci(p[m], y[m])
            out_rows.append(_metric_row(
                f"allcause_auc_sex{s_val}", horizon, "", auc_s, lo_s, hi_s))

    # The tau_max labels are built at max(eval_horizons), so the score must be
    # the CIF at that same horizon. Locate it by value: the last column is only
    # the maximum when the horizons are sorted ascending.
    tau_max = float(max(eval_horizons))
    h_max_idx = int(np.argmax(np.asarray(list(eval_horizons), dtype=float)))
    p_tau_max = cause_cif[:, :, h_max_idx]  # (N, topK)
    cause_id_list = top_cause_ids.tolist()

    # Task A (Delphi2M-compatible discrimination): per-cause AUC with EVER labels
    # case/control is defined by whether the disease appears ANYTIME after context.
    for j, cause_id in enumerate(cause_id_list):
        auc, lo, hi = _auc_with_ci(p_tau_max[:, j], y_cause_ever_anytime[:, j])
        out_rows.append(_metric_row(
            "cause_auc_ever", tau_max, int(cause_id), auc, lo, hi))

    # Keep the existing tau-window AUC as a separate metric (do not remove).
    for j, cause_id in enumerate(cause_id_list):
        auc, lo, hi = _auc_with_ci(p_tau_max[:, j], y_cause_within_tau_max[:, j])
        out_rows.append(_metric_row(
            "cause_auc", tau_max, int(cause_id), auc, lo, hi))

    # Task B additions: cause-specific Brier + calibration curves at the
    # requested horizons; horizons that were not evaluated are skipped.
    horizon_to_idx = {float(t): i for i, t in enumerate(eval_horizons)}
    for tau in cause_calib_horizons:
        horizon = float(tau)
        if horizon not in horizon_to_idx:
            continue
        h_idx = horizon_to_idx[horizon]
        p_tau = cause_cif[:, :, h_idx]  # (N, topK)
        y_tau = y_cause_within_tau[:, :, h_idx]  # (N, topK)
        for j, cause_id in enumerate(cause_id_list):
            cause = int(cause_id)
            p = p_tau[:, j]
            y = y_tau[:, j]
            out_rows.append(_metric_row(
                "cause_brier", horizon, cause, brier_score(p, y)))
            cal = calibration_deciles(p, y, n_bins=n_calib_bins)
            out_rows.append(_metric_row("cause_ece", horizon, cause, cal["ece"]))
            out_rows.append(_metric_row("cause_ici", horizon, cause, cal["ici"]))
            # Cause calibration bins are written only for these horizons.
            for binfo in cal.get("bins", []):
                calib_rows.append(_calib_row("cause_k", horizon, cause, binfo))
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Dump calibration-bin records to ``path`` as a CSV file with a header row.

    Args:
        path: Destination file. It is opened in write mode, so any existing
            content is replaced.
        rows: One mapping per calibration bin, keyed by the column names
            listed below. ``csv.DictWriter`` leaves a missing key blank and
            raises ``ValueError`` for a key that is not a known column.
    """
    columns = [
        "model_name",
        "task",
        "horizon",
        "cause_id",
        "bin_index",
        "p_mean",
        "y_mean",
        "n_in_bin",
    ]
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Dump metric records to ``path`` as a CSV file with a header row.

    Args:
        path: Destination file. It is opened in write mode, so any existing
            content is replaced.
        rows: One mapping per metric value, keyed by the column names listed
            below. ``csv.DictWriter`` leaves a missing key blank and raises
            ``ValueError`` for a key that is not a known column.
    """
    columns = [
        "model_name",
        "metric_name",
        "horizon",
        "cause",
        "value",
        "ci_low",
        "ci_high",
    ]
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
def _make_eval_tag(split: str, offset_years: float) -> str:
    """Build the ``<split>_offset<years>y`` tag used in per-run output filenames.

    The offset is rendered with four decimals and then stripped of trailing
    zeros and of a dangling decimal point, e.g. ``0.5`` -> ``"0.5"`` and
    ``1.0`` -> ``"1"``.
    """
    fixed_point = format(float(offset_years), ".4f")
    # Zeros go first; whatever "." is then left at the end is removed next.
    without_zeros = fixed_point.rstrip("0")
    compact = without_zeros.rstrip(".")
    return f"{split}_offset{compact}y"
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the unified CIF evaluation script."""
    ap = argparse.ArgumentParser(
        description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True,
                    help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str,
                    default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"],
                    help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5,
                    help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float,
                    nargs="*", default=DEFAULT_EVAL_HORIZONS)
    ap.add_argument("--top_k_causes", type=int, default=50)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--out_csv", type=str, default="eval_results.csv")
    ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
    # Integrity checks
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)
    # AUC CI methods
    ap.add_argument(
        "--auc_ci_method",
        type=str,
        default="delong",
        choices=["delong", "bootstrap", "none"],
    )
    ap.add_argument("--bootstrap_n", type=int, default=2000)
    return ap


def _make_eval_loader(
    cfg: Dict[str, Any], args: argparse.Namespace
) -> Tuple[Any, DataLoader]:
    """Build the dataset and the unshuffled evaluation loader for one training config.

    The covariate list, split ratios and split seed are read from ``cfg``
    (with the same fallbacks for missing keys in every caller); the data
    prefix, split name, batch size and worker count come from ``args``.

    Returns:
        ``(dataset, loader)`` where ``loader`` iterates the requested split.
    """
    cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [
        "bmi", "smoking", "alcohol"]
    dataset = HealthDataset(
        data_prefix=args.data_prefix, covariate_list=cov_list)
    subset = build_eval_subset(
        dataset,
        train_ratio=float(cfg.get("train_ratio", 0.7)),
        val_ratio=float(cfg.get("val_ratio", 0.15)),
        seed=int(cfg.get("random_seed", 42)),
        split=args.split,
    )
    loader = DataLoader(
        subset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=health_collate_fn,
    )
    return dataset, loader


def main() -> int:
    """Evaluate every model listed in ``--models_json`` and write the results.

    Steps:
      1. Parse and validate the command line, then seed all RNGs.
      2. Rank causes by their count on the evaluation split built from the
         FIRST model's training config and keep the top ``--top_k_causes``
         with a positive count.
      3. For each model: rebuild its dataset/loader, load the checkpoint,
         predict CIFs, run the integrity check, compute metrics, and write
         ``eval_results_<tag>.csv``, ``calibration_bins_<tag>.csv`` and
         ``eval_meta_<tag>.json`` into the checkpoint's directory.
      4. Write the combined results CSV (``--out_csv``), a
         ``calibration_bins.csv`` next to it, and ``--out_meta_json``.

    Returns:
        ``0`` on success (used as the process exit status).

    Raises:
        ValueError: if the models JSON yields no model specs.
        SystemExit: from argparse on invalid command-line arguments.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()
    # nargs="*" accepts an empty list; reject it here with a clear message
    # instead of failing later inside the evaluation.
    if not args.eval_horizons:
        ap.error("--eval_horizons requires at least one horizon")
    # A negative K would silently slice from the end of the ranking below.
    if args.top_k_causes < 0:
        ap.error("--top_k_causes must be non-negative")

    set_deterministic(args.seed)
    specs = load_models_json(args.models_json)
    if not specs:
        raise ValueError("No models provided")

    # Make sure the output locations exist before the (potentially long)
    # evaluation, so a fresh output path cannot lose finished results.
    out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
    meta_dir = os.path.dirname(os.path.abspath(args.out_meta_json)) or "."
    for directory in (out_dir, meta_dir):
        os.makedirs(directory, exist_ok=True)

    # Determine top-K causes from the evaluation split only (model-agnostic).
    # NOTE(review): the split is built from the first model's config; if other
    # models were trained with different split settings their evaluation
    # subset may differ from the one ranked here - confirm this is intended.
    first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
    dataset_for_top, loader_top = _make_eval_loader(first_cfg, args)
    counts, n_total_eval = count_ever_after_context_anytime(
        loader=loader_top,
        offset_years=args.offset_years,
        n_disease=dataset_for_top.n_disease,
        device=args.device,
    )
    # Descending by count, dropping causes that never occur.
    order = np.argsort(-counts)
    order = order[counts[order] > 0]
    top_cause_ids = order[: args.top_k_causes]

    # Record top-cause counts under Delphi2M-compatible EVER label.
    top_causes_meta: List[Dict[str, Any]] = []
    for k in top_cause_ids.tolist():
        n_case = int(counts[int(k)])
        top_causes_meta.append(
            {
                "cause_id": int(k),
                "n_case_ever": n_case,
                "n_control_ever": int(n_total_eval - n_case),
                "n_total_eval": int(n_total_eval),
            }
        )

    # (metric_name, key in top_causes_meta) pairs, in the order the rows are
    # emitted for every cause.
    topcause_metrics = (
        ("topcause_n_case_ever", "n_case_ever"),
        ("topcause_n_control_ever", "n_control_ever"),
        ("topcause_n_total_eval", "n_total_eval"),
    )

    rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []
    # Track per-model integrity status for meta JSON.
    integrity_meta: Dict[str, Any] = {}
    # The filename tag depends only on the CLI arguments, not on the model.
    tag = _make_eval_tag(args.split, float(args.offset_years))

    # Evaluate each model
    for spec in specs:
        run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))
        # Remember list offsets so we can write per-model slices to the
        # model's run_dir.
        rows_start = len(rows)
        calib_start = len(calib_rows)

        cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
        dataset, loader = _make_eval_loader(cfg, args)

        backbone, head, loss_type, bin_edges = instantiate_model_and_head(
            cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
        ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
        backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
        head.load_state_dict(ckpt["head_state_dict"], strict=True)

        (
            allcause_risk,
            cause_cif,
            cif_full,
            survival,
            sex,
            y_allcause,
            y_cause_ever_anytime,
            y_cause_within_tau,
            y_cause_within_tau_max,
        ) = predict_cifs_for_model(
            backbone,
            head,
            loss_type,
            bin_edges,
            loader,
            args.device,
            args.offset_years,
            args.eval_horizons,
            top_cause_ids,
        )

        # CIF integrity checks before metrics.
        integrity_ok, integrity_notes = check_cif_integrity(
            cif_full,
            args.eval_horizons,
            tol=float(args.integrity_tol),
            name=spec.name,
            strict=bool(args.integrity_strict),
            survival=survival,
        )
        integrity_meta[spec.name] = {
            "integrity_ok": bool(integrity_ok),
            "integrity_notes": integrity_notes,
        }

        evaluate_one_model(
            model_name=spec.name,
            allcause_risk=allcause_risk,
            cause_cif=cause_cif,
            sex=sex,
            y_allcause=y_allcause,
            y_cause_ever_anytime=y_cause_ever_anytime,
            y_cause_within_tau=y_cause_within_tau,
            y_cause_within_tau_max=y_cause_within_tau_max,
            eval_horizons=args.eval_horizons,
            top_cause_ids=top_cause_ids,
            out_rows=rows,
            calib_rows=calib_rows,
            auc_ci_method=str(args.auc_ci_method),
            bootstrap_n=int(args.bootstrap_n),
        )

        # Write top-cause counts into the main results CSV as metric rows.
        for tc in top_causes_meta:
            for metric_name, count_key in topcause_metrics:
                rows.append(
                    {
                        "model_name": spec.name,
                        "metric_name": metric_name,
                        "horizon": "",
                        "cause": int(tc["cause_id"]),
                        "value": int(tc[count_key]),
                        "ci_low": "",
                        "ci_high": "",
                    }
                )

        # Write per-model results into the model's run directory.
        model_rows = rows[rows_start:]
        model_calib_rows = calib_rows[calib_start:]
        model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
        model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")
        write_results_csv(model_out_csv, model_rows)
        write_calibration_bins_csv(model_calib_csv, model_calib_rows)
        model_meta = {
            "model_name": spec.name,
            "checkpoint_path": spec.checkpoint_path,
            "run_dir": run_dir,
            "split": args.split,
            "offset_years": args.offset_years,
            "eval_horizons": [float(x) for x in args.eval_horizons],
            "top_k_causes": int(args.top_k_causes),
            "top_cause_ids": top_cause_ids.tolist(),
            "top_causes": top_causes_meta,
            "integrity": {spec.name: integrity_meta.get(spec.name, {})},
            "paths": {
                "results_csv": model_out_csv,
                "calibration_bins_csv": model_calib_csv,
            },
        }
        with open(model_meta_json, "w") as f:
            json.dump(model_meta, f, indent=2)
        print(f"Wrote per-model results to {model_out_csv}")

    write_results_csv(args.out_csv, rows)
    # Write calibration curve points to a separate CSV next to the results.
    calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
    write_calibration_bins_csv(calib_csv_path, calib_rows)
    meta = {
        "split": args.split,
        "offset_years": args.offset_years,
        "eval_horizons": [float(x) for x in args.eval_horizons],
        "top_k_causes": int(args.top_k_causes),
        "top_cause_ids": top_cause_ids.tolist(),
        "top_causes": top_causes_meta,
        "integrity": integrity_meta,
        "notes": {
            "task_a_label": "Delphi2M-compatible: disease occurs ANYTIME after context (ever in remaining sequence)",
            "task_a_legacy_label": "Secondary: disease occurs within tau_max after context",
            "task_b_label": "all-cause event within horizon (equivalent to next disease event within horizon)",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
        },
    }
    with open(args.out_meta_json, "w") as f:
        json.dump(meta, f, indent=2)
    print(f"Wrote {args.out_csv} with {len(rows)} rows")
    print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
    print(f"Wrote {args.out_meta_json}")
    return 0
if __name__ == "__main__":
    # Script entry point: main()'s integer return value becomes the process
    # exit status.
    exit_status = main()
    raise SystemExit(exit_status)