import argparse
import csv
import json
import math
import os
import random
import statistics
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

from dataset import HealthDataset, health_collate_fn
from model import DelphiFork, SapDelphi, SimpleHead


# ============================================================
# Constants / defaults (aligned with evaluate_prompt.md)
# ============================================================
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, float("inf")]
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
DAYS_PER_YEAR = 365.25
DEFAULT_DEATH_CAUSE_ID = 1256


# ============================================================
# Model specs
# ============================================================
@dataclass(frozen=True)
class ModelSpec:
    name: str
    model_type: str  # delphi_fork | sap_delphi
    loss_type: str  # exponential | discrete_time_cif
    full_cov: bool
    checkpoint_path: str


# ============================================================
# Determinism
# ============================================================

def set_deterministic(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# ============================================================
# Utilities
# ============================================================

def _parse_bool(x: Any) -> bool:
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    if s in {"true", "1", "yes", "y"}:
        return True
    if s in {"false", "0", "no", "n"}:
        return False
    raise ValueError(f"Cannot parse boolean: {x!r}")


def load_models_json(path: str) -> List[ModelSpec]:
    with open(path, "r") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")

    specs: List[ModelSpec] = []
    for row in data:
        specs.append(
            ModelSpec(
                name=str(row["name"]),
                model_type=str(row["model_type"]),
                loss_type=str(row["loss_type"]),
                full_cov=_parse_bool(row["full_cov"]),
                checkpoint_path=str(row["checkpoint_path"]),
            )
        )
    return specs
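
# Example (illustrative sketch, not a real config): load_models_json() expects
# a JSON list of entries with exactly the ModelSpec fields; the name and path
# below are hypothetical placeholders.
# [
#   {"name": "delphi_exp", "model_type": "delphi_fork",
#    "loss_type": "exponential", "full_cov": true,
#    "checkpoint_path": "runs/delphi_exp/ckpt.pt"}
# ]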


def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r") as f:
        cfg = json.load(f)
    return cfg


def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val

    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )

    if split == "train":
        return train_ds
    if split == "val":
        return val_ds
    if split == "test":
        return test_ds
    if split == "all":
        return dataset
    raise ValueError("split must be one of: train, val, test, all")


# ============================================================
# Context selection (anti-leakage)
# ============================================================

def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Select per-sample prediction context index.

    IMPORTANT SEMANTICS:
    - The last observed token time is treated as the FOLLOW-UP END time.
    - We pick the last valid token with time <= (followup_end_time - offset).
    - We do NOT interpret followup_end_time as an event time.

    Returns:
        keep_mask: (B,) bool, which samples have a valid context
        t_ctx: (B,) long, index into sequence
        t_ctx_time: (B,) float, time (days) at context
    """
    # valid tokens are event != 0 (padding is 0)
    valid = event_seq != 0
    lengths = valid.sum(dim=1)
    last_idx = torch.clamp(lengths - 1, min=0)

    b = torch.arange(event_seq.size(0), device=event_seq.device)
    followup_end_time = time_seq[b, last_idx]
    t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR)

    eligible = valid & (time_seq <= t_cut.unsqueeze(1))
    eligible_counts = eligible.sum(dim=1)
    keep = eligible_counts > 0

    t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[b, t_ctx]

    return keep, t_ctx, t_ctx_time
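
# Worked example for the offset cut above (hedged; made-up numbers): for one
# record with token times [0, 1000, 3000, 3650] days, the last time (3650) is
# the follow-up end. With offset_years=0.5 the cut is 3650 - 182.6 ~= 3467
# days, so the context index is 2 (time 3000) and later tokens stay unseen.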


def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return next disease event after context.

    Returns:
        dt_years: (B,) float, time to next disease in years; +inf if none
        cause: (B,) long, disease id in [0,K) for next event; -1 if none
    """
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]

    # Allow same-day events while excluding the context token itself.
    # We rely on time-sorted sequences and select the FIRST valid future event by index.
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)
    idx_min = torch.where(
        future, idxs, torch.full_like(idxs, L)).min(dim=1).values

    has = idx_min < L
    t_next = torch.where(has, idx_min, torch.zeros_like(idx_min))

    t_next_time = time_seq[b, t_next]
    dt_days = t_next_time - t0
    dt_years = dt_days / DAYS_PER_YEAR
    dt_years = torch.where(
        has, dt_years, torch.full_like(dt_years, float("inf")))

    cause_token = event_seq[b, t_next]
    cause = (cause_token - 2).to(torch.long)
    cause = torch.where(has, cause, torch.full_like(cause, -1))

    return dt_years, cause


def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after context (any occurrence)."""
    B, L = event_seq.shape
    b = torch.arange(B, device=event_seq.device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)

    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    # Include same-day events after context, exclude any token at/before context index.
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )

    if not in_window.any():
        return torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)

    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)

    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    y[b_idx, disease_ids] = True
    return y


def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs ANYTIME after the prediction context.

    This is Delphi2M-compatible for Task A case/control definition.
    Same-day events are included as long as they occur after the context token index.
    """
    B, L = event_seq.shape
    idxs = torch.arange(L, device=event_seq.device).unsqueeze(0).expand(B, -1)
    future = (idxs > t_ctx.unsqueeze(1)) & (event_seq >= 2) & (event_seq != 0)

    y = torch.zeros((B, n_disease), dtype=torch.bool, device=event_seq.device)
    if not future.any():
        return y

    b_idx, t_idx = future.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
    y[b_idx, disease_ids] = True
    return y


def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after context?"""
    B, L = event_seq.shape
    device = event_seq.device
    b = torch.arange(B, device=device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)

    idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
        & (event_seq != 0)
    )

    out = torch.zeros((B, cause_ids.numel()), dtype=torch.bool, device=device)
    if not in_window.any():
        return out

    b_idx, t_idx = in_window.nonzero(as_tuple=True)
    disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)

    # Filter to selected causes via a boolean membership mask over the global disease space.
    selected = torch.zeros((int(n_disease),), dtype=torch.bool, device=device)
    selected[cause_ids] = True
    keep = selected[disease_ids]
    if not keep.any():
        return out

    b_idx = b_idx[keep]
    disease_ids = disease_ids[keep]

    # Map disease_id -> local index in cause_ids
    # Build a lookup table (global disease space) where lookup[disease_id] = local_index
    lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
    lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)
    local = lookup[disease_ids]
    out[b_idx, local] = True
    return out
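
# Illustrative sketch of the lookup-table mapping above (made-up ids): with
# n_disease=6 and cause_ids=[4, 1], lookup becomes [-1, 1, -1, -1, 0, -1], so
# an in-window event with disease_id=4 sets column 0 of `out`, disease_id=1
# sets column 1, and all other diseases are dropped by the `selected` mask.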


# ============================================================
# CIF conversion
# ============================================================

def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential cause-specific logits -> CIFs at taus.

    logits: (B, K)
    returns: (B, K, H) or (cif, survival) if return_survival
    """
    hazards = F.softplus(logits) + eps
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)

    taus_t = torch.tensor(list(taus), device=logits.device,
                          dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B,1,1)

    # (1 - exp(-Lambda * tau))
    one_minus_surv = 1.0 - torch.exp(-total_h * taus_t)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B,K,H)
    # If total==0, set to 0
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))

    if not return_survival:
        return cif

    survival = torch.exp(-total_h * taus_t).squeeze(1)  # (B,H)
    # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival
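
# Worked example (hedged; made-up hazards): with two causes and per-year
# hazards lambda = (0.10, 0.05), Lambda = 0.15, so at tau = 10 the all-cause
# event probability is 1 - exp(-1.5) ~= 0.777 and the cause-1 CIF is
# (0.10 / 0.15) * 0.777 ~= 0.518, i.e. the formula used above:
#   CIF_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau)).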


def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert discrete-time CIF logits -> CIFs at taus.

    logits: (B, K+1, n_bins+1)
    bin_edges: len=n_bins+1 (including 0 and inf)
    taus: subset of finite bin edges (recommended)

    returns: (B, K, H) or (cif, survival) if return_survival
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")

    B, K_plus_1, n_bins_plus_1 = logits.shape
    K = K_plus_1 - 1

    edges = [float(x) for x in bin_edges]
    # drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)

    if n_bins_plus_1 != len(edges):
        raise ValueError("logits last dim must match len(bin_edges)")

    probs = torch.softmax(logits, dim=1)  # (B, K+1, n_bins+1)

    # use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot)
    hazards = probs[:, :K, 1: 1 + n_bins]  # (B,K,n_bins)
    p_comp = probs[:, K, 1: 1 + n_bins]  # (B,n_bins)

    # survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v<u} p_comp[v]
    ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype)
    cum = torch.cumprod(p_comp, dim=1)
    s_prev = torch.cat([ones, cum[:, :-1]], dim=1)  # (B, n_bins)

    cif_bins = torch.cumsum(s_prev.unsqueeze(1) * hazards, dim=2)  # (B,K,n_bins)

    # Robust mapping from tau -> edge index (floating-point safe).
    # taus are expected to align with bin edges, but may differ slightly due to parsing/serialization.
    finite_edges_arr = np.asarray(finite_edges, dtype=float)
    tau_to_idx: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        diffs = np.abs(finite_edges_arr - tau_f)
        j = int(np.argmin(diffs))
        if diffs[j] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})"
            )
        tau_to_idx.append(j)

    idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,K,H)

    if not return_survival:
        return cif

    # Survival at each horizon = prod_{u <= idx[h]} p_comp[u]
    survival_bins = cum  # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v]
    survival = survival_bins.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival
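
# Example grounded in the module defaults: with DEFAULT_BIN_EDGES the finite
# edges after dropping 0 are [0.24, 0.72, 1.61, 3.84, 10.0, 31.0], so
# DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0] maps to bin indices
# [1, 2, 3, 4] and the returned CIF carries H=4 horizon slices.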


# ============================================================
# CIF integrity checks
# ============================================================

def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on CIF arrays.

    Args:
        cause_cif: (N, K, H)
        horizons: length H
        tol: tolerance for inequalities
        name: model name for messages
        strict: if True, raise ValueError on first failure
        survival: optional (N, H) survival values at the same horizons

    Returns:
        (integrity_ok, notes)
    """
    notes: List[str] = []
    model_tag = f"[{name}] " if name else ""

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    N, K, H = cif.shape
    if H != len(horizons):
        _fail(
            f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")

    # (5) Finite
    if not np.isfinite(cif).all():
        _fail("integrity: non-finite values (NaN/Inf) in cause_cif")

    # (1) Range
    cmin = float(np.nanmin(cif))
    cmax = float(np.nanmax(cif))
    if cmin < -tol:
        _fail(f"integrity: range min={cmin} < -tol={-tol}")
    if cmax > 1.0 + tol:
        _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")

    # (2) Monotonicity in horizons (per n,k)
    diffs = np.diff(cif, axis=2)
    if diffs.size > 0:
        if np.nanmin(diffs) < -tol:
            _fail("integrity: monotonicity violated (found negative diff along horizons)")

    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N,H)
    mass_max = float(np.nanmax(mass))
    if mass_max > 1.0 + tol:
        _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")

    # (4) Conservation with survival, if provided
    if survival is None:
        warn = "integrity: survival not provided; skipping conservation check"
        if strict:
            # still skip (requested behavior), but keep message for context
            notes.append(warn)
        else:
            print("WARNING:", model_tag + warn)
            notes.append(warn)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(
                f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        else:
            recon = 1.0 - s
            err = np.abs(recon - mass)
            # Discrete-time should be very tight; exponential may accumulate slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if float(np.nanmax(err)) > tol_cons:
                _fail(
                    f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})")

    ok = len([n for n in notes if not n.endswith(
        "skipping conservation check")]) == 0
    return ok, notes


# ============================================================
# Metrics
# ============================================================

# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---

def compute_midrank(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x, dtype=float)
    order = np.argsort(x)
    z = x[order]
    n = x.shape[0]
    t = np.zeros(n, dtype=float)
    i = 0
    while i < n:
        j = i
        while j < n and z[j] == z[i]:
            j += 1
        t[i:j] = 0.5 * (i + j - 1) + 1.0
        i = j
    out = np.empty(n, dtype=float)
    out[order] = t
    return out


def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for computing AUC covariance.

    predictions_sorted_transposed: shape (n_classifiers, n_examples) with positive examples first.
    """
    preds = np.asarray(predictions_sorted_transposed, dtype=float)
    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        return np.array([float("nan")]), np.array([[float("nan")]])

    pos = preds[:, :m]
    neg = preds[:, m:]

    tx = np.array([compute_midrank(x) for x in pos])
    ty = np.array([compute_midrank(x) for x in neg])
    tz = np.array([compute_midrank(x) for x in preds])

    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)

    v01 = (tz[:, :m] - tx) / n
    v10 = 1.0 - (tz[:, m:] - ty) / m

    if v01.shape[0] > 1:
        sx = np.cov(v01)
        sy = np.cov(v10)
    else:
        # Single-classifier case: compute row-wise variance (do not flatten).
        var_v01 = float(np.var(v01, axis=1, ddof=1)[0])
        var_v10 = float(np.var(v10, axis=1, ddof=1)[0])
        sx = np.array([[var_v01]])
        sy = np.array([[var_v10]])
    delong_cov = sx / m + sy / n
    return aucs, delong_cov


def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    y = np.asarray(ground_truth, dtype=int)
    p = np.asarray(predictions, dtype=float)
    if y.ndim != 1 or p.ndim != 1 or y.shape[0] != p.shape[0]:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")

    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan"), float("nan")

    order = np.argsort(-y)  # positives first
    preds_sorted = p[order]
    aucs, cov = fastDeLong(preds_sorted[np.newaxis, :], m)
    auc = float(aucs[0])
    var = float(cov[0, 0])
    return auc, var


def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return (auc, ci_low, ci_high) using DeLong variance and normal CI."""
    auc, var = calc_auc_variance(ground_truth, predictions)
    if not np.isfinite(var) or var <= 0:
        print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
        return float(auc), float("nan"), float("nan")

    sd = math.sqrt(var)
    z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0)
    lo = max(0.0, auc - z * sd)
    hi = min(1.0, auc + z * sd)
    return float(auc), float(lo), float(hi)
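
# Usage sketch (hedged; toy arrays, not real predictions):
#   y = np.array([1, 0, 1, 0, 1, 0, 0, 1])
#   p = np.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.1, 0.5, 0.8])
#   auc, lo, hi = delong_ci(y, p, alpha=0.95)
# delong_ci() returns the midrank AUC with a normal-approximation CI clipped
# to [0, 1]; degenerate labels fall back to NaN CIs with a warning.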


def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC via Mann–Whitney U statistic (ties handled by midranks).

    Returns NaN for degenerate labels.
    """
    y = np.asarray(y_true, dtype=int)
    s = np.asarray(y_score, dtype=float)
    if y.ndim != 1 or s.ndim != 1 or y.shape[0] != s.shape[0]:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")
    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan")

    ranks = compute_midrank(s)
    sum_pos = float(np.sum(ranks[y == 1]))
    auc = (sum_pos - m * (m + 1) / 2.0) / (m * n)
    return float(auc)


def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Bootstrap CI for ROC AUC (percentile)."""
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    if n == 0 or np.all(labels == labels[0]):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    aucs: List[float] = []
    for _ in range(int(n_bootstrap)):
        idx = rng.integers(0, n, size=n)
        yb = labels[idx]
        if np.all(yb == yb[0]):
            continue
        pb = scores[idx]
        auc = roc_auc_rank(yb, pb)
        if np.isfinite(auc):
            aucs.append(float(auc))

    if len(aucs) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), float("nan"), float("nan")

    lo_q = (1.0 - float(alpha)) / 2.0
    hi_q = 1.0 - lo_q
    lo = float(np.quantile(aucs, lo_q))
    hi = float(np.quantile(aucs, hi_q))
    return float(auc_full), lo, hi


def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.mean((p - y) ** 2))


def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)

    # guard
    if p.size == 0:
        return {"bins": [], "ici": float("nan")}

    edges = np.quantile(p, np.linspace(0.0, 1.0, n_bins + 1))
    # make strictly increasing where possible
    edges[0] = -np.inf
    edges[-1] = np.inf

    bins = []
    ici_accum = 0.0
    n = p.shape[0]

    for i in range(n_bins):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        if not np.any(mask):
            continue
        p_mean = float(np.mean(p[mask]))
        y_mean = float(np.mean(y[mask]))
        bins.append({"bin": i, "p_mean": p_mean,
                     "y_mean": y_mean, "n": int(mask.sum())})
        ici_accum += abs(p_mean - y_mean)

    ici = ici_accum / max(len(bins), 1)
    return {"bins": bins, "ici": float(ici)}


def _safe_float(x: Any, default: float = float("nan")) -> float:
    try:
        return float(x)
    except Exception:
        return float(default)


def _ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)


def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
    """Load 0-based cause_id -> name mapping.

    labels.csv is assumed to be one label per line, in disease-id order.
    """
    if not os.path.exists(path):
        return {}
    mapping: Dict[int, str] = {}
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            name = line.strip()
            if name:
                mapping[int(i)] = name
    return mapping


def pick_focus_causes(
    *,
    counts_within_tau: Optional[np.ndarray],
    n_disease: int,
    death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
    k: int = 5,
) -> List[int]:
    """Pick focus causes for user-facing evaluation.

    Rule:
    1) Always include death_cause_id first.
    2) Then add K additional causes by descending event count if available.
       If counts_within_tau is None, fall back to descending cause_id coverage proxy.

    Notes:
    - counts_within_tau is expected to be shape (n_disease,).
    - Deterministic: ties broken by smaller cause id.
    """
    n_disease_i = int(n_disease)
    if death_cause_id < 0 or death_cause_id >= n_disease_i:
        print(
            f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={n_disease_i}); "
            "it will be omitted from focus causes."
        )
        focus: List[int] = []
    else:
        focus = [int(death_cause_id)]

    candidates = [i for i in range(n_disease_i) if i != int(death_cause_id)]

    if counts_within_tau is not None:
        c = np.asarray(counts_within_tau).astype(float)
        if c.shape[0] != n_disease_i:
            print(
                "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering."
            )
            counts_within_tau = None
        else:
            # Sort by (-count, cause_id)
            order = sorted(candidates, key=lambda i: (-float(c[i]), int(i)))
            order = [i for i in order if float(c[i]) > 0]
            focus.extend([int(i) for i in order[: int(k)]])

    if counts_within_tau is None:
        # Fallback: deterministic coverage proxy (descending id, excluding death), then take K.
        # (Real coverage requires data; this path is mostly for robustness.)
        order = sorted(candidates, key=lambda i: (-int(i)))
        focus.extend([int(i) for i in order[: int(k)]])

    # De-dup while preserving order
    seen = set()
    out: List[int] = []
    for cid in focus:
        if cid not in seen:
            out.append(cid)
            seen.add(cid)
    return out


def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
    _ensure_dir(os.path.dirname(os.path.abspath(path)) or ".")
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
    """Return list of (sex_label, mask) slices including an 'all' slice.

    If sex is missing, returns only ('all', None).
    """
    out: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
    if sex is None:
        return out
    s = np.asarray(sex)
    if s.ndim != 1:
        return out
    for val in [0, 1]:
        m = (s == val)
        if int(np.sum(m)) > 0:
            out.append((str(val), m))
    return out


def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
    edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1))
    edges = np.asarray(edges, dtype=float)
    edges[0] = -np.inf
    edges[-1] = np.inf
    return edges


def compute_risk_stratification_bins(
    p: np.ndarray,
    y: np.ndarray,
    *,
    q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
    """Compute quantile-based risk strata and a compact summary."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return 0, [], {
            "y_overall": float("nan"),
            "top_decile_y_rate": float("nan"),
            "bottom_half_y_rate": float("nan"),
            "lift_top10_vs_bottom50": float("nan"),
            "slope_pred_vs_obs": float("nan"),
        }

    # Choose quantiles robustly.
    q = int(q_default)
    if n < 200:
        q = 5

    edges = _quantile_edges(p, q)
    y_overall = float(np.mean(y))
    bin_rows: List[Dict[str, Any]] = []
    p_means: List[float] = []
    y_rates: List[float] = []
    n_bins: List[int] = []

    for i in range(q):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        nb = int(np.sum(mask))
        if nb == 0:
            # Keep the row for consistent plotting; set NaNs.
            bin_rows.append(
                {
                    "q": int(i + 1),
                    "n_bin": 0,
                    "p_mean": float("nan"),
                    "y_rate": float("nan"),
                    "y_overall": y_overall,
                    "lift_vs_overall": float("nan"),
                }
            )
            continue
        p_mean = float(np.mean(p[mask]))
        y_rate = float(np.mean(y[mask]))
        lift = float(y_rate / y_overall) if y_overall > 0 else float("nan")
        bin_rows.append(
            {
                "q": int(i + 1),
                "n_bin": nb,
                "p_mean": p_mean,
                "y_rate": y_rate,
                "y_overall": y_overall,
                "lift_vs_overall": lift,
            }
        )
        p_means.append(p_mean)
        y_rates.append(y_rate)
        n_bins.append(nb)

    # Summary
    top_mask = (p > edges[q - 1]) & (p <= edges[q])
    bot_half_mask = (p > edges[0]) & (p <= edges[q // 2])
    top_y = float(np.mean(y[top_mask])) if int(
        np.sum(top_mask)) > 0 else float("nan")
    bot_y = float(np.mean(y[bot_half_mask])) if int(
        np.sum(bot_half_mask)) > 0 else float("nan")
    lift_top_vs_bottom = float(top_y / bot_y) if (np.isfinite(top_y)
                                                  and np.isfinite(bot_y) and bot_y > 0) else float("nan")

    slope = float("nan")
    if len(p_means) >= 2:
        # Weighted least squares slope of y_rate ~ p_mean.
        x = np.asarray(p_means, dtype=float)
        yy = np.asarray(y_rates, dtype=float)
        w = np.asarray(n_bins, dtype=float)
        xm = float(np.average(x, weights=w))
        ym = float(np.average(yy, weights=w))
        denom = float(np.sum(w * (x - xm) ** 2))
        if denom > 0:
            slope = float(np.sum(w * (x - xm) * (yy - ym)) / denom)

    summary = {
        "y_overall": y_overall,
        "top_decile_y_rate": top_y,
        "bottom_half_y_rate": bot_y,
        "lift_top10_vs_bottom50": lift_top_vs_bottom,
        "slope_pred_vs_obs": slope,
    }
    return q, bin_rows, summary
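
# Usage sketch (hedged; synthetic inputs): given per-person predicted risks p
# and binary within-horizon outcomes y for one cause,
#   q_used, bin_rows, summary = compute_risk_stratification_bins(p, y)
# yields deciles (quintiles when n < 200), per-bin mean risk vs. observed
# event rate, and summary lifts such as lift_top10_vs_bottom50.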


def compute_capture_points(
    p: np.ndarray,
    y: np.ndarray,
    k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    n = int(p.shape[0])
    if n == 0:
        return []
    order = np.argsort(-p)
    y_sorted = y[order]
    events_total = float(np.sum(y_sorted))

    rows: List[Dict[str, Any]] = []
    for k in k_pcts:
        kf = float(k)
        n_targeted = int(math.ceil(n * kf / 100.0))
        n_targeted = max(1, min(n_targeted, n))
        events_targeted = float(np.sum(y_sorted[:n_targeted]))
        capture = float(events_targeted /
                        events_total) if events_total > 0 else float("nan")
        precision = float(events_targeted / float(n_targeted))
        rows.append(
            {
                "k_pct": int(k),
                "n_targeted": int(n_targeted),
                "events_targeted": float(events_targeted),
                "events_total": float(events_total),
                "event_capture_rate": capture,
                "precision_in_targeted": precision,
            }
        )
    return rows


def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]:
    """Bucketize horizons into short/medium/long using the continuous-horizon rule."""
    uniq = sorted({float(h) for h in horizons})
    mapping: Dict[float, str] = {}
    rows: List[Dict[str, Any]] = []
    # First 4 short, next 4 medium, rest long.
    for i, h in enumerate(uniq):
        if i < 4:
            g, gr = "short", 1
        elif i < 8:
            g, gr = "medium", 2
        else:
            g, gr = "long", 3
        mapping[float(h)] = g
        rows.append({"horizon": float(h), "group": g, "group_rank": int(gr)})
    method = "continuous_unique_horizons_first4_next4_rest"
    return rows, mapping, method
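
# Example grounded in the defaults above: make_horizon_groups(
# DEFAULT_EVAL_HORIZONS) labels all four default horizons (0.72, 1.61, 3.84,
# 10.0 years) as "short"; "medium" and "long" only appear once more than four
# and more than eight distinct horizons are evaluated, respectively.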


def count_occurs_within_horizon(
    loader: DataLoader,
    offset_years: float,
    tau_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count per-person occurrence within tau after the prediction context.

    Returns counts[k] = number of individuals with disease k at least once in (t_ctx, t_ctx+tau].
    """
    counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_total_eval = 0
    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)

        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue

        n_total_eval += int(keep.sum().item())
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        t_ctx = t_ctx[keep]

        B, L = event_seq.shape
        b = torch.arange(B, device=device)
        t0 = time_seq[b, t_ctx]
        t1 = t0 + (float(tau_years) * DAYS_PER_YEAR)

        idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
        in_window = (
            (idxs > t_ctx.unsqueeze(1))
            & (time_seq >= t0.unsqueeze(1))
            & (time_seq <= t1.unsqueeze(1))
            & (event_seq >= 2)
            & (event_seq != 0)
        )
        if not in_window.any():
            continue

        b_idx, t_idx = in_window.nonzero(as_tuple=True)
        disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)

        # unique per (person, disease) to count per-person within-window occurrence
        key = b_idx.to(torch.long) * int(n_disease) + disease_ids
        uniq = torch.unique(key)
        uniq_disease = uniq % int(n_disease)
        counts.scatter_add_(0, uniq_disease, torch.ones_like(
            uniq_disease, dtype=torch.long))

    return counts.detach().cpu().numpy(), int(n_total_eval)


# ============================================================
# Evaluation core
# ============================================================

def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])

    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")

    if model_type == "delphi_fork":
        backbone = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
        ).to(device)
    elif model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in {"", None}:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in {"", None}:
            run_dir = os.path.dirname(os.path.abspath(
                checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=2,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=dataset.n_cont,
            n_cate=dataset.n_cate,
            cate_dims=dataset.cate_dims,
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
    return backbone, head, loss_type, bin_edges


@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run model and produce cause-specific, time-dependent CIF outputs.

    Returns:
        cause_cif: (N, topK, H)
        cif_full: (N, K, H)
        survival: (N, H)
        y_cause_within_tau: (N, topK, H)
        sex: (N,)

    NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
    """
    backbone.eval()
    head.eval()

    # We will accumulate in CPU lists, then concat.
    cause_cif_list: List[np.ndarray] = []
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []
    top_cause_ids_t = torch.tensor(
        top_cause_ids, dtype=torch.long, device=device)

    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue

        # filter batch
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep]
        t_ctx = t_ctx[keep]

        h = backbone(event_seq, time_seq, sexes_k,
                     cont_feats, cate_feats)  # (B,L,D)
        b = torch.arange(h.size(0), device=device)
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)

        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                # (B,K,H), (B,H)
                logits, bin_edges, eval_horizons, return_survival=True)
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")

        cause_cif = cif_full.index_select(
            dim=1, index=top_cause_ids_t)  # (B,topK,H)

        # Within-horizon labels for cause-specific CIF quality + discrimination.
        n_disease = int(cif_full.size(1))
        y_within_top = torch.stack(
            [
                multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau),
                    cause_ids=top_cause_ids_t,
                    n_disease=n_disease,
                ).to(torch.float32)
                for tau in eval_horizons
            ],
            dim=2,
        )  # (B,topK,H)

        cause_cif_list.append(cause_cif.detach().cpu().numpy())
        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())

    if not cause_cif_list:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")

    cause_cif = np.concatenate(cause_cif_list, axis=0)
    cif_full = np.concatenate(cif_full_list, axis=0)
    survival = np.concatenate(survival_list, axis=0)
    y_cause_within = np.concatenate(y_cause_within_list, axis=0)
    sex = np.concatenate(
        sex_list, axis=0) if sex_list else np.array([], dtype=int)

    return cause_cif, cif_full, survival, y_cause_within, sex


def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    counts = y_ever.sum(axis=0)
    order = np.argsort(-counts)
    order = order[counts[order] > 0]
    return order[:top_k]


def evaluate_one_model(
    model_name: str,
    cause_cif: np.ndarray,
    y_cause_within_tau: np.ndarray,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
) -> None:
    # Cause-specific, time-dependent metrics per horizon.
    for h_i, tau in enumerate(eval_horizons):
        p_tau = cause_cif[:, :, h_i]  # (N, topK)
        y_tau = y_cause_within_tau[:, :, h_i]  # (N, topK)

        for j, cause_id in enumerate(top_cause_ids.tolist()):
            p = p_tau[:, j]
            y = y_tau[:, j]

            # Primary: CIF-based Brier score + ICI (calibration).
            out_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_brier",
                    "horizon": float(tau),
                    "cause": int(cause_id),
                    "value": brier_score(p, y),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

            cal = calibration_deciles(p, y, n_bins=n_calib_bins)
            out_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_ici",
                    "horizon": float(tau),
                    "cause": int(cause_id),
                    "value": cal["ici"],
                    "ci_low": "",
                    "ci_high": "",
                }
            )

            # Secondary: discrimination via AUC at the same horizon.
            if auc_ci_method == "none":
                auc, lo, hi = float("nan"), float("nan"), float("nan")
            elif auc_ci_method == "bootstrap":
                auc, lo, hi = bootstrap_auc_ci(
                    p, y, n_bootstrap=bootstrap_n, alpha=0.95)
            else:
                auc, lo, hi = delong_ci(y, p, alpha=0.95)
            out_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_auc",
                    "horizon": float(tau),
                    "cause": int(cause_id),
                    "value": auc,
                    "ci_low": lo,
                    "ci_high": hi,
                }
            )

            # Calibration curve bins for this cause + horizon.
            for binfo in cal.get("bins", []):
                calib_rows.append(
                    {
                        "model_name": model_name,
                        "task": "cause_k",
                        "horizon": float(tau),
                        "cause_id": int(cause_id),
                        "bin_index": int(binfo["bin"]),
                        "p_mean": float(binfo["p_mean"]),
                        "y_mean": float(binfo["y_mean"]),
                        "n_in_bin": int(binfo["n"]),
                    }
                )


def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    fieldnames = [
        "model_name",
        "task",
        "horizon",
        "cause_id",
        "bin_index",
        "p_mean",
        "y_mean",
        "n_in_bin",
    ]
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    fieldnames = [
        "model_name",
        "metric_name",
        "horizon",
        "cause",
        "value",
        "ci_low",
        "ci_high",
    ]
    with open(path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for r in rows:
            w.writerow(r)


def _make_eval_tag(split: str, offset_years: float) -> str:
    """Short tag for filenames written into run directories."""
    off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
    return f"{split}_offset{off}y"
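
# Example invocation (illustrative; the script and file names are
# placeholders for however this module is saved in the repository):
#   python evaluate.py --models_json models.json --data_prefix ukb \
#       --split test --offset_years 0.5 --eval_horizons 0.72 1.61 3.84 10.0 \
#       --auc_ci_method delong --out_csv eval_results.csv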


def main() -> int:
    ap = argparse.ArgumentParser(
        description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True,
                    help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str,
                    default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"], help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5,
                    help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float,
                    nargs="*", default=DEFAULT_EVAL_HORIZONS)
    ap.add_argument("--top_k_causes", type=int, default=50)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--out_csv", type=str, default="eval_results.csv")
    ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")

    # Integrity checks
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)

    # AUC CI methods
    ap.add_argument(
        "--auc_ci_method",
        type=str,
        default="delong",
        choices=["delong", "bootstrap", "none"],
    )
    ap.add_argument("--bootstrap_n", type=int, default=2000)

    # Export settings for user-facing experiments
    ap.add_argument("--export_dir", type=str, default="eval_exports")
    ap.add_argument("--death_cause_id", type=int,
                    default=DEFAULT_DEATH_CAUSE_ID)
    ap.add_argument("--focus_k", type=int, default=5,
                    help="Additional non-death causes to include")
    ap.add_argument("--capture_k_pcts", type=int,
                    nargs="*", default=[1, 5, 10, 20])
    ap.add_argument(
        "--capture_curve_max_pct",
        type=int,
        default=50,
        help="If >0, also export a dense capture curve for k=1..max_pct",
    )
    args = ap.parse_args()

    set_deterministic(args.seed)

    specs = load_models_json(args.models_json)
    if not specs:
        raise ValueError("No models provided")

    export_dir = str(args.export_dir)
    _ensure_dir(export_dir)
    cause_names = load_cause_names("labels.csv")

    # Determine top-K causes from the evaluation split only (model-agnostic),
    # aligned to time-dependent risk: occurrence within tau_max after context.
    first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
    cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [
        "bmi", "smoking", "alcohol"]
    dataset_for_top = HealthDataset(
        data_prefix=args.data_prefix, covariate_list=cov_list)
    subset_for_top = build_eval_subset(
        dataset_for_top,
        train_ratio=float(first_cfg.get("train_ratio", 0.7)),
        val_ratio=float(first_cfg.get("val_ratio", 0.15)),
        seed=int(first_cfg.get("random_seed", 42)),
        split=args.split,
    )
    loader_top = DataLoader(
        subset_for_top,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=health_collate_fn,
    )

    tau_max = float(max(args.eval_horizons))
    counts, n_total_eval = count_occurs_within_horizon(
        loader=loader_top,
        offset_years=args.offset_years,
        tau_years=tau_max,
        n_disease=dataset_for_top.n_disease,
        device=args.device,
    )

    focus_causes = pick_focus_causes(
        counts_within_tau=counts,
        n_disease=int(dataset_for_top.n_disease),
        death_cause_id=int(args.death_cause_id),
        k=int(args.focus_k),
    )
    top_cause_ids = np.asarray(focus_causes, dtype=int)

    # Export the chosen focus causes.
    focus_rows: List[Dict[str, Any]] = []
    for r, cid in enumerate(focus_causes, start=1):
        row: Dict[str, Any] = {"cause": int(cid), "rank": int(r)}
        if cid in cause_names:
            row["cause_name"] = cause_names[cid]
        focus_rows.append(row)
    focus_fieldnames = ["cause", "rank"] + \
        (["cause_name"] if any("cause_name" in r for r in focus_rows) else [])
    write_simple_csv(os.path.join(export_dir, "focus_causes.csv"),
                     focus_fieldnames, focus_rows)

    # Metadata for focus causes (within tau_max).
    top_causes_meta: List[Dict[str, Any]] = []
    for cid in focus_causes:
        n_case = int(counts[int(cid)]) if int(cid) < int(counts.shape[0]) else 0
        top_causes_meta.append(
            {
                "cause_id": int(cid),
                "tau_years": float(tau_max),
                "n_case_within_tau": n_case,
                "n_control_within_tau": int(n_total_eval - n_case),
                "n_total_eval": int(n_total_eval),
            }
        )

    # Horizon groups for Experiment 3
    hg_rows, horizon_to_group, hg_method = make_horizon_groups(
        args.eval_horizons)
    write_simple_csv(
        os.path.join(export_dir, "horizon_groups.csv"),
        ["horizon", "group", "group_rank"],
        hg_rows,
    )
rows: List[Dict[str, Any]] = []
|
||
calib_rows: List[Dict[str, Any]] = []
|
||
|
||
# Experiment exports (accumulated across models)
|
||
rs_bins_rows: List[Dict[str, Any]] = []
|
||
rs_sum_rows: List[Dict[str, Any]] = []
|
||
cap_points_rows: List[Dict[str, Any]] = []
|
||
cap_curve_rows: List[Dict[str, Any]] = []
|
||
cal_group_sum_rows: List[Dict[str, Any]] = []
|
||
cal_group_bins_rows: List[Dict[str, Any]] = []
|
||
|
||
# Track per-model integrity status for meta JSON.
|
||
integrity_meta: Dict[str, Any] = {}
|
||
|
||
# Evaluate each model
|
||
for spec in specs:
|
||
run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))
|
||
tag = _make_eval_tag(args.split, float(args.offset_years))
|
||
|
||
# Remember list offsets so we can write per-model slices to the model's run_dir.
|
||
rows_start = len(rows)
|
||
calib_start = len(calib_rows)
|
||
|
||
cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
|
||
|
||
# Identifiers for consistent exports
|
||
model_id = str(spec.name)
|
||
model_type = str(
|
||
cfg.get("model_type", spec.model_type if hasattr(spec, "model_type") else ""))
|
||
loss_type_id = str(
|
||
cfg.get("loss_type", spec.loss_type if hasattr(spec, "loss_type") else ""))
|
||
age_encoder = str(cfg.get("age_encoder", ""))
|
||
cov_type = "full" if _parse_bool(
|
||
cfg.get("full_cov", False)) else "partial"
|
||
|
||
cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [
|
||
"bmi", "smoking", "alcohol"]
|
||
dataset = HealthDataset(
|
||
data_prefix=args.data_prefix, covariate_list=cov_list)
|
||
subset = build_eval_subset(
|
||
dataset,
|
||
train_ratio=float(cfg.get("train_ratio", 0.7)),
|
||
val_ratio=float(cfg.get("val_ratio", 0.15)),
|
||
seed=int(cfg.get("random_seed", 42)),
|
||
split=args.split,
|
||
)
|
||
loader = DataLoader(
|
||
subset,
|
||
batch_size=args.batch_size,
|
||
shuffle=False,
|
||
num_workers=args.num_workers,
|
||
collate_fn=health_collate_fn,
|
||
)
|
||
|
||
backbone, head, loss_type, bin_edges = instantiate_model_and_head(
|
||
cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
|
||
ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
|
||
backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
|
||
head.load_state_dict(ckpt["head_state_dict"], strict=True)
|
||
|
||
(
|
||
cause_cif,
|
||
cif_full,
|
||
survival,
|
||
y_cause_within_tau,
|
||
sex,
|
||
) = predict_cifs_for_model(
|
||
backbone,
|
||
head,
|
||
loss_type,
|
||
bin_edges,
|
||
loader,
|
||
args.device,
|
||
args.offset_years,
|
||
args.eval_horizons,
|
||
top_cause_ids,
|
||
)
|
||
|
||
# CIF integrity checks before metrics.
|
||
integrity_ok, integrity_notes = check_cif_integrity(
|
||
cif_full,
|
||
args.eval_horizons,
|
||
tol=float(args.integrity_tol),
|
||
name=spec.name,
|
||
strict=bool(args.integrity_strict),
|
||
survival=survival,
|
||
)
|
||
integrity_meta[spec.name] = {
|
||
"integrity_ok": bool(integrity_ok),
|
||
"integrity_notes": integrity_notes,
|
||
}
|
||
|
||
evaluate_one_model(
|
||
model_name=spec.name,
|
||
cause_cif=cause_cif,
|
||
y_cause_within_tau=y_cause_within_tau,
|
||
eval_horizons=args.eval_horizons,
|
||
top_cause_ids=top_cause_ids,
|
||
out_rows=rows,
|
||
calib_rows=calib_rows,
|
||
auc_ci_method=str(args.auc_ci_method),
|
||
bootstrap_n=int(args.bootstrap_n),
|
||
)
|
||
|
||
# ============================================================
|
||
# Experiment 1: Risk stratification bins + summary
|
||
# ============================================================
|
||
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
|
||
for h_i, tau in enumerate(args.eval_horizons):
|
||
for j, cause_id in enumerate(top_cause_ids.tolist()):
|
||
p = cause_cif[:, j, h_i]
|
||
y = y_cause_within_tau[:, j, h_i]
|
||
if sex_mask is not None:
|
||
p = p[sex_mask]
|
||
y = y[sex_mask]
|
||
q_used, bin_rows, summary = compute_risk_stratification_bins(
|
||
p, y, q_default=10)
|
||
for br in bin_rows:
|
||
rs_bins_rows.append(
|
||
{
|
||
"model_id": model_id,
|
||
"model_type": model_type,
|
||
"loss_type": loss_type_id,
|
||
"age_encoder": age_encoder,
|
||
"cov_type": cov_type,
|
||
"cause": int(cause_id),
|
||
"horizon": float(tau),
|
||
"sex": sex_label,
|
||
"q": int(br["q"]),
|
||
"n_bin": int(br["n_bin"]),
|
||
"p_mean": _safe_float(br["p_mean"]),
|
||
"y_rate": _safe_float(br["y_rate"]),
|
||
"y_overall": _safe_float(br["y_overall"]),
|
||
"lift_vs_overall": _safe_float(br["lift_vs_overall"]),
|
||
"q_total": int(q_used),
|
||
}
|
||
)
|
||
rs_sum_rows.append(
|
||
{
|
||
"model_id": model_id,
|
||
"model_type": model_type,
|
||
"loss_type": loss_type_id,
|
||
"age_encoder": age_encoder,
|
||
"cov_type": cov_type,
|
||
"cause": int(cause_id),
|
||
"horizon": float(tau),
|
||
"sex": sex_label,
|
||
"q_total": int(q_used),
|
||
"top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]),
|
||
"bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]),
|
||
"lift_top10_vs_bottom50": _safe_float(summary["lift_top10_vs_bottom50"]),
|
||
"slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]),
|
||
}
|
||
)
|
||
|
||
# ============================================================
|
||
# Experiment 2: High-risk capture points (+ optional curve)
|
||
# ============================================================
|
||
k_pcts = [int(x) for x in args.capture_k_pcts]
|
||
curve_max = int(args.capture_curve_max_pct)
|
||
curve_grid = list(range(1, curve_max + 1)
|
||
) if curve_max and curve_max > 0 else []
|
||
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
|
||
for h_i, tau in enumerate(args.eval_horizons):
|
||
for j, cause_id in enumerate(top_cause_ids.tolist()):
|
||
p = cause_cif[:, j, h_i]
|
||
y = y_cause_within_tau[:, j, h_i]
|
||
if sex_mask is not None:
|
||
p = p[sex_mask]
|
||
y = y[sex_mask]
|
||
|
||
for r in compute_capture_points(p, y, k_pcts):
|
||
cap_points_rows.append(
|
||
{
|
||
"model_id": model_id,
|
||
"model_type": model_type,
|
||
"loss_type": loss_type_id,
|
||
"age_encoder": age_encoder,
|
||
"cov_type": cov_type,
|
||
"cause": int(cause_id),
|
||
"horizon": float(tau),
|
||
"sex": sex_label,
|
||
**r,
|
||
}
|
||
)
|
||
if curve_grid:
|
||
for r in compute_capture_points(p, y, curve_grid):
|
||
cap_curve_rows.append(
|
||
{
|
||
"model_id": model_id,
|
||
"model_type": model_type,
|
||
"loss_type": loss_type_id,
|
||
"age_encoder": age_encoder,
|
||
"cov_type": cov_type,
|
||
"cause": int(cause_id),
|
||
"horizon": float(tau),
|
||
"sex": sex_label,
|
||
**r,
|
||
}
|
||
)
|
||
|
||
# ============================================================
|
||
# Experiment 3: Short/Medium/Long horizon-group calibration
|
||
# ============================================================
|
||
# Per-horizon metrics for grouping
|
||
# Build a dict for quick access: (cause_id, horizon) -> (brier, ici)
|
||
per_h: Dict[Tuple[int, float], Dict[str, float]] = {}
|
||
for rr in rows[rows_start:]:
|
||
if rr.get("model_name") != spec.name:
|
||
continue
|
||
if rr.get("metric_name") not in {"cause_brier", "cause_ici"}:
|
||
continue
|
||
try:
|
||
cid = int(rr.get("cause"))
|
||
except Exception:
|
||
continue
|
||
h = _safe_float(rr.get("horizon"))
|
||
if not np.isfinite(h):
|
||
continue
|
||
key = (cid, float(h))
|
||
d = per_h.get(key, {})
|
||
d[str(rr.get("metric_name"))] = _safe_float(rr.get("value"))
|
||
per_h[key] = d
|
||
|
||
        # Compute group summaries and pooled bins using the same quantile bins as exp1 (per slice).
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for j, cause_id in enumerate(top_cause_ids.tolist()):
                # Decide Q per slice for pooled reliability curve
                n_slice = (
                    int(np.sum(sex_mask)) if sex_mask is not None else int(sex.shape[0])
                )
                q_pool = 10 if n_slice >= 200 else 5

                # Collect per-horizon brier/ici values
                group_vals: Dict[str, Dict[str, List[float]]] = {
                    "short": {"brier": [], "ici": []},
                    "medium": {"brier": [], "ici": []},
                    "long": {"brier": [], "ici": []},
                }
                group_n_total: Dict[str, int] = {"short": 0, "medium": 0, "long": 0}

                # Pooled bins: group -> q -> accumulators
                pooled: Dict[str, Dict[int, Dict[str, float]]] = {
                    "short": {},
                    "medium": {},
                    "long": {},
                }

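                # Accumulate reliability bins per horizon group, weighting each horizon's
                # bin by its size so the pooled p_mean / y_rate are n-weighted averages.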
                for h_i, tau in enumerate(args.eval_horizons):
                    g = horizon_to_group.get(float(tau), "long")

                    # brier/ici per horizon (already computed at full-sample level)
                    d = per_h.get((int(cause_id), float(tau)), {})
                    brier_h = _safe_float(d.get("cause_brier"))
                    ici_h = _safe_float(d.get("cause_ici"))
                    if np.isfinite(brier_h):
                        group_vals[g]["brier"].append(brier_h)
                    if np.isfinite(ici_h):
                        group_vals[g]["ici"].append(ici_h)

                    # pooled reliability bins from raw p/y
                    p = cause_cif[:, j, h_i]
                    y = y_cause_within_tau[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]
                    if p.size == 0:
                        continue
                    edges = _quantile_edges(p, q_pool)
                    for qi in range(q_pool):
                        m = (p > edges[qi]) & (p <= edges[qi + 1])
                        nb = int(np.sum(m))
                        if nb == 0:
                            continue
                        pm = float(np.mean(p[m]))
                        yr = float(np.mean(y[m]))
                        acc = pooled[g].get(
                            qi + 1, {"n": 0.0, "p_sum": 0.0, "y_sum": 0.0}
                        )
                        acc["n"] += float(nb)
                        acc["p_sum"] += float(nb) * pm
                        acc["y_sum"] += float(nb) * yr
                        pooled[g][qi + 1] = acc
                    group_n_total[g] = max(group_n_total[g], int(p.size))

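                # Emit one summary row per horizon group plus its pooled reliability bins.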
                for g in ["short", "medium", "long"]:
                    bvals = group_vals[g]["brier"]
                    ivals = group_vals[g]["ici"]
                    cal_group_sum_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "cause": int(cause_id),
                            "sex": sex_label,
                            "horizon_group": g,
                            "brier_mean": float(np.mean(bvals)) if bvals else float("nan"),
                            "brier_median": float(np.median(bvals)) if bvals else float("nan"),
                            "ici_mean": float(np.mean(ivals)) if ivals else float("nan"),
                            "ici_median": float(np.median(ivals)) if ivals else float("nan"),
                            "n_total": int(group_n_total[g]),
                            "horizon_grouping_method": hg_method,
                        }
                    )

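                    # Pooled bins: convert the n-weighted sums back into mean predicted
                    # risk (p_mean) and observed event rate (y_rate) for each quantile.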
                    for qi in range(1, q_pool + 1):
                        acc = pooled[g].get(qi)
                        if not acc or float(acc.get("n", 0.0)) <= 0:
                            continue
                        n_bin = float(acc["n"])
                        cal_group_bins_rows.append(
                            {
                                "model_id": model_id,
                                "model_type": model_type,
                                "loss_type": loss_type_id,
                                "age_encoder": age_encoder,
                                "cov_type": cov_type,
                                "cause": int(cause_id),
                                "sex": sex_label,
                                "horizon_group": g,
                                "q": int(qi),
                                "n_bin": int(n_bin),
                                "p_mean": float(acc["p_sum"] / n_bin),
                                "y_rate": float(acc["y_sum"] / n_bin),
                                "q_total": int(q_pool),
                                "horizon_grouping_method": hg_method,
                            }
                        )

        # Optionally write top-cause counts into the main results CSV as metric rows.
        for tc in top_causes_meta:
            rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_case_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_case_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_control_within_tau",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_control_within_tau"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )
            rows.append(
                {
                    "model_name": spec.name,
                    "metric_name": "topcause_n_total_eval",
                    "horizon": float(tc["tau_years"]),
                    "cause": int(tc["cause_id"]),
                    "value": int(tc["n_total_eval"]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

        # Write per-model results into the model's run directory.
        model_rows = rows[rows_start:]
        model_calib_rows = calib_rows[calib_start:]
        model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
        model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")

        write_results_csv(model_out_csv, model_rows)
        write_calibration_bins_csv(model_calib_csv, model_calib_rows)

        model_meta = {
            "model_name": spec.name,
            "checkpoint_path": spec.checkpoint_path,
            "run_dir": run_dir,
            "split": args.split,
            "offset_years": args.offset_years,
            "eval_horizons": [float(x) for x in args.eval_horizons],
            "top_k_causes": int(args.top_k_causes),
            "top_cause_ids": top_cause_ids.tolist(),
            "top_causes": top_causes_meta,
            "integrity": {spec.name: integrity_meta.get(spec.name, {})},
            "paths": {
                "results_csv": model_out_csv,
                "calibration_bins_csv": model_calib_csv,
            },
        }
        with open(model_meta_json, "w") as f:
            json.dump(model_meta, f, indent=2)

        print(f"Wrote per-model results to {model_out_csv}")

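    # Combined artifacts across all models go to the top-level output paths.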
    write_results_csv(args.out_csv, rows)

    # Write calibration curve points to a separate CSV.
    out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
    calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
    write_calibration_bins_csv(calib_csv_path, calib_rows)

    # Write experiment exports
    write_simple_csv(
        os.path.join(export_dir, "risk_stratification_bins.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "q",
            "n_bin",
            "p_mean",
            "y_rate",
            "y_overall",
            "lift_vs_overall",
            "q_total",
        ],
        rs_bins_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "risk_stratification_summary.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "q_total",
            "top_decile_y_rate",
            "bottom_half_y_rate",
            "lift_top10_vs_bottom50",
            "slope_pred_vs_obs",
        ],
        rs_sum_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "lift_capture_points.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "k_pct",
            "n_targeted",
            "events_targeted",
            "events_total",
            "event_capture_rate",
            "precision_in_targeted",
        ],
        cap_points_rows,
    )
    if cap_curve_rows:
        write_simple_csv(
            os.path.join(export_dir, "lift_capture_curve.csv"),
            [
                "model_id",
                "model_type",
                "loss_type",
                "age_encoder",
                "cov_type",
                "cause",
                "horizon",
                "sex",
                "k_pct",
                "n_targeted",
                "events_targeted",
                "events_total",
                "event_capture_rate",
                "precision_in_targeted",
            ],
            cap_curve_rows,
        )
    write_simple_csv(
        os.path.join(export_dir, "calibration_groups_summary.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "sex",
            "horizon_group",
            "brier_mean",
            "brier_median",
            "ici_mean",
            "ici_median",
            "n_total",
            "horizon_grouping_method",
        ],
        cal_group_sum_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "calibration_groups_bins.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "sex",
            "horizon_group",
            "q",
            "n_bin",
            "p_mean",
            "y_rate",
            "q_total",
            "horizon_grouping_method",
        ],
        cal_group_bins_rows,
    )

    # Manifest markdown (stable, user-facing)
    manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
    with open(manifest_path, "w", encoding="utf-8") as f:
        f.write(
            "# Evaluation Exports Manifest\n\n"
            "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
            "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
            "## Files\n\n"
            "- focus_causes.csv: The deterministically selected focus causes (Death + top-K). Intended plot: bar of event support + label table.\n"
            "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n"
            "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n"
            "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n"
            "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
            "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
            "- calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket.\n"
            "- calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). Intended plot: 3-panel reliability curves per model.\n"
        )

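    # Run-level metadata (settings, focus causes, integrity checks) for the whole evaluation.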
    meta = {
        "split": args.split,
        "offset_years": args.offset_years,
        "eval_horizons": [float(x) for x in args.eval_horizons],
        "tau_max": float(tau_max),
        "top_k_causes": int(args.top_k_causes),
        "top_cause_ids": top_cause_ids.tolist(),
        "top_causes": top_causes_meta,
        "integrity": integrity_meta,
        "notes": {
            "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
            "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
            "secondary_metrics": "cause_auc (discrimination) with optional CI",
            "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
            "exports_dir": export_dir,
            "focus_causes": focus_causes,
            "horizon_grouping_method": hg_method,
        },
    }
    with open(args.out_meta_json, "w") as f:
        json.dump(meta, f, indent=2)

print(f"Wrote {args.out_csv} with {len(rows)} rows")
|
||
print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
|
||
print(f"Wrote {args.out_meta_json}")
|
||
return 0
|
||
|
||
|
||
if __name__ == "__main__":
|
||
raise SystemExit(main())