1555 lines
53 KiB
Python
1555 lines
53 KiB
Python
import argparse
|
||
import csv
|
||
import json
|
||
import math
|
||
import os
|
||
import random
|
||
import statistics
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch.utils.data import DataLoader, random_split
|
||
|
||
from dataset import HealthDataset, health_collate_fn
|
||
from losses import DiscreteTimeCIFNLLLoss
|
||
from model import DelphiFork, SapDelphi, SimpleHead
|
||
|
||
|
||
# ============================================================
|
||
# Constants / defaults (aligned with evaluate_prompt.md)
|
||
# ============================================================
|
||
# Discrete-time bin edges in years: a leading 0, finite interior edges, and a
# final open-ended +inf edge. Horizons evaluated with the discrete-time CIF head
# must coincide with one of the finite edges.
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, math.inf]
# Default evaluation horizons in years; each equals a finite edge above.
DEFAULT_EVAL_HORIZONS = [0.72, 1.61, 3.84, 10.0]
# Conversion between day-valued sequence timestamps and year-valued horizons.
DAYS_PER_YEAR = 365.25
|
||
|
||
|
||
# ============================================================
|
||
# Model specs
|
||
# ============================================================
|
||
@dataclass(frozen=True)
class ModelSpec:
    """One model entry to evaluate, as parsed from the models JSON list.

    Immutable (``frozen=True``); built by ``load_models_json``.
    """

    # Display/identifier name taken from the JSON "name" field.
    name: str
    model_type: str  # delphi_fork | sap_delphi
    loss_type: str  # exponential | discrete_time_cif
    # Parsed from JSON "full_cov" via _parse_bool.
    # NOTE(review): presumably "model trained with the full covariate set" — confirm.
    full_cov: bool
    # Path to the model checkpoint file, as given in the JSON.
    checkpoint_path: str
|
||
|
||
|
||
# ============================================================
|
||
# Determinism
|
||
# ============================================================
|
||
|
||
def set_deterministic(seed: int) -> None:
    """Seed every RNG used during evaluation and pin cuDNN kernel selection.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus all CUDA
    devices), then forces deterministic cuDNN kernels and disables autotuning.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
||
|
||
|
||
# ============================================================
|
||
# Utilities
|
||
# ============================================================
|
||
|
||
def _parse_bool(x: Any) -> bool:
|
||
if isinstance(x, bool):
|
||
return x
|
||
s = str(x).strip().lower()
|
||
if s in {"true", "1", "yes", "y"}:
|
||
return True
|
||
if s in {"false", "0", "no", "n"}:
|
||
return False
|
||
raise ValueError(f"Cannot parse boolean: {x!r}")
|
||
|
||
|
||
def load_models_json(path: str) -> List[ModelSpec]:
    """Load the list of models to evaluate from a JSON file.

    The file must hold a JSON list whose entries are objects with the keys
    ``name``, ``model_type``, ``loss_type``, ``full_cov`` and
    ``checkpoint_path``. ``full_cov`` is parsed with ``_parse_bool``. The other
    fields are coerced with ``str``.

    Raises:
        ValueError: if the top-level JSON value is not a list, or a
            ``full_cov`` value is not a recognized boolean spelling.
        KeyError: if an entry lacks one of the required keys.
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale
    # encoding, which breaks non-ASCII model names on e.g. Windows.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")

    return [
        ModelSpec(
            name=str(row["name"]),
            model_type=str(row["model_type"]),
            loss_type=str(row["loss_type"]),
            full_cov=_parse_bool(row["full_cov"]),
            checkpoint_path=str(row["checkpoint_path"]),
        )
        for row in data
    ]
|
||
|
||
|
||
def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    """Load the ``train_config.json`` stored alongside a checkpoint.

    The config is looked up in the directory containing *checkpoint_path*
    (resolved to an absolute path first).

    Raises:
        FileNotFoundError: if that directory has no ``train_config.json``.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    # JSON is UTF-8 by spec; don't rely on the locale's default encoding.
    with open(cfg_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||
|
||
|
||
def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    """Return one split of *dataset* (or the whole dataset).

    The partition is produced by ``random_split`` with a generator seeded by
    *seed*, using sizes ``train = int(N * train_ratio)``,
    ``val = int(N * val_ratio)`` and ``test = N - train - val``.
    NOTE(review): this only reproduces the training split if training used the
    identical call — confirm against the training script.

    Args:
        dataset: the full dataset.
        train_ratio, val_ratio: fractions of N assigned to train/val.
        seed: seed for the split generator.
        split: "train", "val", "test", or "all" (the unsplit dataset).

    Raises:
        ValueError: for an unknown *split*, or (for train/val/test) ratios
            that yield a negative split size.
    """
    if split not in ("train", "val", "test", "all"):
        raise ValueError("split must be one of: train, val, test, all")
    if split == "all":
        return dataset

    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    # random_split does not reject negative lengths, it silently produces
    # wrong-sized subsets, so validate here.
    if min(n_train, n_val, n_test) < 0:
        raise ValueError(
            f"invalid split ratios (train_ratio={train_ratio}, val_ratio={val_ratio}): "
            f"sizes would be train={n_train}, val={n_val}, test={n_test} for N={n_total}"
        )

    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    return {"train": train_ds, "val": val_ds, "test": test_ds}[split]
|
||
|
||
|
||
# ============================================================
|
||
# Context selection (anti-leakage)
|
||
# ============================================================
|
||
|
||
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pick, per sample, the token used as prediction context.

    IMPORTANT SEMANTICS:
      - The time of the last non-padding token is the FOLLOW-UP END time.
      - The context is the last non-padding token whose time is
        <= followup_end_time - offset_years * DAYS_PER_YEAR.
      - The follow-up end time is NOT interpreted as an event time.

    Token counts are used directly as positions, so this assumes sequences are
    right-padded with 0 and sorted by time (TODO confirm against the dataset).

    Returns:
        keep_mask: (B,) bool, True where a context token exists.
        t_ctx: (B,) long, context position (0 for rows without a context).
        t_ctx_time: (B,) time at the context position (days).
    """
    is_token = event_seq != 0  # id 0 is padding
    rows = torch.arange(event_seq.size(0), device=event_seq.device)

    # Position of the last real token; an all-padding row falls back to 0.
    last_pos = (is_token.sum(dim=1) - 1).clamp(min=0)
    end_of_followup = time_seq[rows, last_pos]
    cutoff = end_of_followup - (offset_years * DAYS_PER_YEAR)

    n_eligible = (is_token & (time_seq <= cutoff[:, None])).sum(dim=1)
    keep_mask = n_eligible > 0
    ctx_pos = (n_eligible - 1).clamp(min=0).to(torch.long)
    return keep_mask, ctx_pos, time_seq[rows, ctx_pos]
|
||
|
||
|
||
def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the first disease token strictly after the context position.

    Disease tokens have id >= 2 (disease index = id - 2). Same-day events count
    as long as they sit after the context position; the context token itself
    is excluded. Relies on time-sorted sequences.

    Returns:
        dt_years: (B,) time from context to that event in years; +inf if none.
        cause: (B,) long disease index in [0, K) of that event; -1 if none.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(n_rows, device=dev)
    ctx_time = time_seq[rows, t_ctx]

    positions = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    is_candidate = (positions > t_ctx.unsqueeze(1)) & (event_seq >= 2)
    # Non-candidates get the out-of-range sentinel seq_len so min() finds the first hit.
    sentinel = torch.full_like(positions, seq_len)
    first_pos = torch.where(is_candidate, positions, sentinel).min(dim=1).values
    found = first_pos < seq_len
    safe_pos = torch.where(found, first_pos, torch.zeros_like(first_pos))

    gap_years = (time_seq[rows, safe_pos] - ctx_time) / DAYS_PER_YEAR
    gap_years = torch.where(found, gap_years, torch.full_like(gap_years, float("inf")))

    next_cause = (event_seq[rows, safe_pos] - 2).to(torch.long)
    next_cause = torch.where(found, next_cause, torch.full_like(next_cause, -1))
    return gap_years, next_cause
|
||
|
||
|
||
def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs within tau after context (any occurrence).

    A token counts when it is a disease token (id >= 2), sits after the context
    position, and its time lies in [t_ctx_time, t_ctx_time + tau_years * DAYS_PER_YEAR].

    Returns:
        (B, n_disease) bool tensor.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(n_rows, device=dev)
    start = time_seq[rows, t_ctx]
    stop = start + (tau_years * DAYS_PER_YEAR)

    positions = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    hit = (
        (positions > t_ctx[:, None])
        & (time_seq >= start[:, None])
        & (time_seq <= stop[:, None])
        & (event_seq >= 2)
    )

    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=dev)
    if hit.any():
        row_idx, pos_idx = hit.nonzero(as_tuple=True)
        labels[row_idx, (event_seq[row_idx, pos_idx] - 2).to(torch.long)] = True
    return labels
|
||
|
||
|
||
def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Binary labels: disease k occurs ANYTIME after the prediction context.

    Delphi2M-compatible case/control definition for Task A. Any disease token
    (id >= 2, disease index = id - 2) positioned after the context token
    counts, including same-day events.

    Returns:
        (B, n_disease) bool tensor.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device

    # Broadcast (1, L) positions against (B, 1) context positions.
    after_ctx = torch.arange(seq_len, device=dev)[None, :] > t_ctx[:, None]
    is_disease_after = after_ctx & (event_seq >= 2)

    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=dev)
    if is_disease_after.any():
        row_idx, pos_idx = is_disease_after.nonzero(as_tuple=True)
        labels[row_idx, (event_seq[row_idx, pos_idx] - 2).to(torch.long)] = True
    return labels
|
||
|
||
|
||
def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Labels for selected causes only: does cause k occur within tau after context?

    A token counts when it is a disease token (id >= 2, disease index = id - 2),
    sits after the context position, and its time lies in
    [t_ctx_time, t_ctx_time + tau_years * DAYS_PER_YEAR].

    Args:
        cause_ids: (C,) disease indices in [0, n_disease) selecting the
            output columns.

    Returns:
        (B, C) bool tensor; column j is True iff disease cause_ids[j] occurs in
        the window. Repeated ids in cause_ids each get the correct label.
    """
    B, L = event_seq.shape
    device = event_seq.device
    b = torch.arange(B, device=device)
    t0 = time_seq[b, t_ctx]
    t1 = t0 + (tau_years * DAYS_PER_YEAR)

    idxs = torch.arange(L, device=device).unsqueeze(0).expand(B, -1)
    in_window = (
        (idxs > t_ctx.unsqueeze(1))
        & (time_seq >= t0.unsqueeze(1))
        & (time_seq <= t1.unsqueeze(1))
        & (event_seq >= 2)
    )

    # Build the multi-hot over the full disease space, then gather the selected
    # columns. Gathering (instead of scattering through an id -> column lookup
    # table) keeps every column right even when cause_ids repeats an id.
    y_all = torch.zeros((B, int(n_disease)), dtype=torch.bool, device=device)
    if in_window.any():
        b_idx, t_idx = in_window.nonzero(as_tuple=True)
        disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)
        y_all[b_idx, disease_ids] = True
    return y_all.index_select(1, cause_ids)
|
||
|
||
|
||
# ============================================================
|
||
# CIF conversion
|
||
# ============================================================
|
||
|
||
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential cause-specific logits -> CIFs at taus.

    Each cause gets a constant hazard lambda_k = softplus(logit_k) + eps. With
    total hazard Lambda = sum_k lambda_k:

        CIF_k(tau) = (lambda_k / Lambda) * (1 - exp(-Lambda * tau))
        S(tau)     = exp(-Lambda * tau)

    Args:
        logits: (B, K) cause-specific logits.
        taus: H horizons, in the same time unit as the hazard rate.
        eps: hazard floor added after softplus; also the divisor clamp.
        return_survival: also return overall survival at each horizon.

    Returns:
        cif (B, K, H), or the tuple (cif, survival (B, H)) when
        return_survival is True.
    """
    hazards = F.softplus(logits) + eps
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)

    taus_t = torch.tensor(list(taus), device=logits.device,
                          dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B,1,1)
    cum_hazard = total_h * taus_t  # (B,1,H)

    # 1 - exp(-x) written as -expm1(-x): avoids catastrophic cancellation when
    # Lambda * tau is small (about 1% error in float32 at x ~ 1e-6 otherwise).
    one_minus_surv = -torch.expm1(-cum_hazard)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B,K,H)
    # Only reachable with eps == 0: zero total hazard means zero CIF.
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))

    if not return_survival:
        return cif

    survival = torch.exp(-cum_hazard).squeeze(1)  # (B,H)
    # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival
|
||
|
||
|
||
def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert discrete-time CIF logits -> CIFs at taus.

    A softmax over dim 1 turns each time slot into per-cause event
    probabilities plus a final "no event" probability (row K). Slot u
    (1 <= u <= n_finite) is the interval ending at the u-th finite edge after
    the leading edge. Slot 0 and the slots beyond the finite edges (the +inf
    bin) are ignored.

        S(u)      = prod_{v <= u} p_noevent(v)
        CIF_k(u)  = sum_{v <= u} S(v - 1) * p_k(v)

    Args:
        logits: (B, K+1, n_bins+1), where n_bins+1 == len(bin_edges).
        bin_edges: edges including the leading 0 and trailing +inf.
        taus: H finite horizons; each must match a finite edge within 1e-6.
        return_survival: also return S at each horizon.

    Returns:
        cif (B, K, H), or the tuple (cif, survival (B, H)) when
        return_survival is True.

    Raises:
        ValueError: on bad logits rank, slot/edge count mismatch, non-finite
            tau, or a tau that is not on a finite edge.
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")

    batch, n_rows_with_comp, n_slots = logits.shape
    n_causes = n_rows_with_comp - 1

    edge_values = list(map(float, bin_edges))
    # Closing edge of every usable bin: skip the leading 0 and any +inf.
    closing_edges = [e for e in edge_values[1:] if math.isfinite(e)]
    n_finite = len(closing_edges)

    if n_slots != len(edge_values):
        raise ValueError("logits last dim must match len(bin_edges)")

    probs = logits.softmax(dim=1)
    usable = slice(1, 1 + n_finite)
    cause_prob = probs[:, :n_causes, usable]  # (B,K,n_finite)
    no_event_prob = probs[:, n_causes, usable]  # (B,n_finite)

    surv_after = no_event_prob.cumprod(dim=1)  # survival at the end of each bin
    head = torch.ones((batch, 1), device=logits.device, dtype=probs.dtype)
    surv_before = torch.cat([head, surv_after[:, :-1]], dim=1)  # survival entering each bin
    cif_by_bin = (surv_before.unsqueeze(1) * cause_prob).cumsum(dim=2)

    # Map each tau to its bin with a float tolerance, since horizons may pick
    # up tiny differences through parsing/serialization.
    edge_arr = np.asarray(closing_edges, dtype=float)
    chosen: List[int] = []
    for tau in taus:
        tau_val = float(tau)
        if not math.isfinite(tau_val):
            raise ValueError("taus must be finite for discrete-time CIF")
        gap = np.abs(edge_arr - tau_val)
        nearest = int(np.argmin(gap))
        if gap[nearest] > 1e-6:
            raise ValueError(
                f"tau={tau_val} not close to any finite bin edge (min |edge-tau|={gap[nearest]})"
            )
        chosen.append(nearest)

    sel = torch.tensor(chosen, device=logits.device, dtype=torch.long)
    cif = cif_by_bin.index_select(dim=2, index=sel)
    if not return_survival:
        return cif
    return cif, surv_after.index_select(dim=1, index=sel)
|
||
|
||
|
||
# ============================================================
|
||
# CIF integrity checks
|
||
# ============================================================
|
||
|
||
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on CIF arrays.

    Checks, in order:
      - shape: cause_cif must be 3-D, and H must equal len(horizons);
      - (5) all values finite;
      - (1) values within [-tol, 1 + tol];
      - (2) non-decreasing along the horizon axis for every (n, k);
      - (3) sum over causes <= 1 + tol for every (n, h);
      - (4) if *survival* is given, |(1 - survival) - sum_k CIF| <= max(tol, 1e-4).
    Check (4) only holds when cause_cif covers every cause, not a subset.
    Zero-size arrays (e.g. K == 0 selected causes) are accepted; value checks
    that need at least one element are skipped.

    Args:
        cause_cif: (N, K, H)
        horizons: length H
        tol: tolerance for inequalities
        name: model name for messages
        strict: if True, raise ValueError on first failure
        survival: optional (N, H) survival values at the same horizons

    Returns:
        (integrity_ok, notes). Failures print a WARNING and are appended to
        notes. A missing *survival* adds an informational note that does not
        affect integrity_ok.
    """
    notes: List[str] = []
    failures: List[str] = []
    model_tag = f"[{name}] " if name else ""

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)
        failures.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    N, K, H = cif.shape
    if H != len(horizons):
        _fail(
            f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")

    # (5) Finite
    if not np.isfinite(cif).all():
        _fail("integrity: non-finite values (NaN/Inf) in cause_cif")

    # (1) Range. nanmin/nanmax raise on zero-size input, so guard on size.
    if cif.size > 0:
        cmin = float(np.nanmin(cif))
        cmax = float(np.nanmax(cif))
        if cmin < -tol:
            _fail(f"integrity: range min={cmin} < -tol={-tol}")
        if cmax > 1.0 + tol:
            _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")

    # (2) Monotonicity in horizons (per n,k)
    diffs = np.diff(cif, axis=2)
    if diffs.size > 0 and np.nanmin(diffs) < -tol:
        _fail("integrity: monotonicity violated (found negative diff along horizons)")

    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N,H)
    if mass.size > 0:
        mass_max = float(np.nanmax(mass))
        if mass_max > 1.0 + tol:
            _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")

    # (4) Conservation with survival, if provided
    if survival is None:
        warn = "integrity: survival not provided; skipping conservation check"
        if not strict:
            print("WARNING:", model_tag + warn)
        # Informational only: recorded in notes but never counted as a failure.
        notes.append(warn)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(
                f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        else:
            err = np.abs((1.0 - s) - mass)
            # Discrete-time should be very tight; exponential may accumulate slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if err.size > 0:
                err_max = float(np.nanmax(err))
                if err_max > tol_cons:
                    _fail(
                        f"integrity: conservation violated (max |(1-surv)-sum_cif|={err_max}, tol={tol_cons})")

    return not failures, notes
|
||
|
||
|
||
# ============================================================
|
||
# Metrics
|
||
# ============================================================
|
||
|
||
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
|
||
|
||
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of *x*, giving tied values the mean of their ranks.

    Runs of equal values in sorted order share one rank,
    0.5 * (first + last) + 1 with 0-based positions. Each NaN compares unequal
    to everything, so it forms its own group and is ranked last (argsort
    places NaN at the end). The previous scalar loop never advanced on NaN
    and hung forever.

    Args:
        x: 1-D array of scores.

    Returns:
        1-D float array of midranks aligned with *x*.
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    if n == 0:
        return np.empty(0, dtype=float)

    order = np.argsort(x)
    z = x[order]

    # A new tie group starts wherever a sorted value differs from its predecessor.
    starts_group = np.ones(n, dtype=bool)
    starts_group[1:] = z[1:] != z[:-1]
    group_start = np.flatnonzero(starts_group)
    group_end = np.append(group_start[1:], n)  # exclusive end of each group
    group_rank = 0.5 * (group_start + group_end - 1) + 1.0

    out = np.empty(n, dtype=float)
    out[order] = np.repeat(group_rank, group_end - group_start)
    return out
|
||
|
||
|
||
def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for computing AUC covariance.

    Args:
        predictions_sorted_transposed: shape (n_classifiers, n_examples), with
            the label_1_count positive examples in the leading columns.
        label_1_count: number of positive examples m.

    Returns:
        (aucs, cov): aucs has shape (n_classifiers,) and cov has shape
        (n_classifiers, n_classifiers). Both are NaN placeholders of shape (1,)
        and (1, 1) when either class is empty.
    """
    scores = np.asarray(predictions_sorted_transposed, dtype=float)
    n_pos = int(label_1_count)
    n_neg = int(scores.shape[1] - n_pos)
    if n_pos <= 0 or n_neg <= 0:
        return np.array([float("nan")]), np.array([[float("nan")]])

    pos_scores, neg_scores = scores[:, :n_pos], scores[:, n_pos:]
    rank_within_pos = np.array([compute_midrank(row) for row in pos_scores])
    rank_within_neg = np.array([compute_midrank(row) for row in neg_scores])
    rank_overall = np.array([compute_midrank(row) for row in scores])

    rank_overall_pos = rank_overall[:, :n_pos]
    aucs = (rank_overall_pos.sum(axis=1) - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)

    # Structural components: v01 per positive example, v10 per negative example.
    v01 = (rank_overall_pos - rank_within_pos) / n_neg
    v10 = 1.0 - (rank_overall[:, n_pos:] - rank_within_neg) / n_pos

    if v01.shape[0] > 1:
        sx, sy = np.cov(v01), np.cov(v10)
    else:
        # One classifier: np.cov would return a 0-d array; keep a 1x1 matrix.
        sx = np.array([[float(np.var(v01, axis=1, ddof=1)[0])]])
        sy = np.array([[float(np.var(v10, axis=1, ddof=1)[0])]])
    return aucs, sx / n_pos + sy / n_neg
|
||
|
||
|
||
def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    """Return (AUC, DeLong variance) for one binary classifier.

    Labels are cast to int; 1 counts as positive and 0 as negative. Returns
    (nan, nan) when either class is absent.

    Raises:
        ValueError: unless both inputs are 1-D with equal length.
    """
    labels = np.asarray(ground_truth, dtype=int)
    scores = np.asarray(predictions, dtype=float)
    shapes_ok = labels.ndim == 1 and scores.ndim == 1 and labels.shape[0] == scores.shape[0]
    if not shapes_ok:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")

    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())
    if min(n_pos, n_neg) == 0:
        return float("nan"), float("nan")

    positives_first = np.argsort(-labels)
    aucs, cov = fastDeLong(scores[positives_first][np.newaxis, :], n_pos)
    return float(aucs[0]), float(cov[0, 0])
|
||
|
||
|
||
def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return (auc, ci_low, ci_high) from the DeLong variance and a normal approximation.

    *alpha* is the confidence level (0.95 -> two-sided 95% interval). Bounds
    are clipped to [0, 1]. When the variance is NaN or not positive, a
    warning is printed and both bounds are NaN.
    """
    auc, var = calc_auc_variance(ground_truth, predictions)
    if np.isfinite(var) and var > 0:
        z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0)
        half_width = z * math.sqrt(var)
        return float(auc), float(max(0.0, auc - half_width)), float(min(1.0, auc + half_width))

    print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
    return float(auc), float("nan"), float("nan")
|
||
|
||
|
||
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC via the Mann–Whitney U statistic (ties get midranks).

    Labels are cast to int; 1 is positive and 0 is negative. Returns NaN when
    either class is absent.

    Raises:
        ValueError: unless both inputs are 1-D with equal length.
    """
    labels = np.asarray(y_true, dtype=int)
    scores = np.asarray(y_score, dtype=float)
    if labels.ndim != 1 or scores.ndim != 1 or labels.shape[0] != scores.shape[0]:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")

    is_pos = labels == 1
    n_pos = int(is_pos.sum())
    n_neg = int((labels == 0).sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    rank_sum_pos = float(compute_midrank(scores)[is_pos].sum())
    u_statistic = rank_sum_pos - n_pos * (n_pos + 1) / 2.0
    return float(u_statistic / (n_pos * n_neg))
|
||
|
||
|
||
def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Percentile-bootstrap CI for ROC AUC.

    Returns (auc_on_full_data, ci_low, ci_high). *alpha* is the confidence
    level. Resamples containing a single class are skipped. If fewer than 10
    valid resamples remain, the bounds are NaN. Degenerate input labels give
    all-NaN output.
    """
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    nan = float("nan")
    degenerate_msg = "WARNING: bootstrap AUC CI degenerate labels; CI set to NaN"

    if n == 0 or np.all(labels == labels[0]):
        print(degenerate_msg)
        return nan, nan, nan

    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print(degenerate_msg)
        return nan, nan, nan

    resampled: List[float] = []
    for _ in range(int(n_bootstrap)):
        draw = rng.integers(0, n, size=n)
        labels_b = labels[draw]
        if np.all(labels_b == labels_b[0]):
            continue  # single-class resample: AUC undefined
        auc_b = roc_auc_rank(labels_b, scores[draw])
        if np.isfinite(auc_b):
            resampled.append(float(auc_b))

    if len(resampled) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), nan, nan

    tail = (1.0 - float(alpha)) / 2.0
    low = float(np.quantile(resampled, tail))
    high = float(np.quantile(resampled, 1.0 - tail))
    return float(auc_full), low, high
|
||
|
||
|
||
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    """Mean squared difference between predicted probabilities *p* and outcomes *y*."""
    residual = np.asarray(p, dtype=float) - np.asarray(y, dtype=float)
    return float(np.mean(np.square(residual)))
|
||
|
||
|
||
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    """Quantile-binned calibration summary of predictions *p* against outcomes *y*.

    Bin boundaries are the empirical quantiles of *p*, with the outer edges
    opened to -inf/+inf. Bins are right-closed. Ties can collapse bins, and
    empty bins are skipped.

    Returns:
        dict with
          - "bins": list of {"bin", "p_mean", "y_mean", "n"} for non-empty bins,
          - "ece": sample-fraction-weighted mean |p_mean - y_mean|,
          - "ici": unweighted mean |p_mean - y_mean| over non-empty bins
            (a binned summary, not a smoothed-curve estimate).
        Empty input gives no bins and NaN for both scores.
    """
    probs = np.asarray(p, dtype=float)
    outcomes = np.asarray(y, dtype=float)
    if probs.size == 0:
        return {"bins": [], "ece": float("nan"), "ici": float("nan")}

    cuts = np.quantile(probs, np.linspace(0.0, 1.0, n_bins + 1))
    cuts[0], cuts[-1] = -np.inf, np.inf

    bins: List[Dict[str, Any]] = []
    gaps: List[float] = []
    ece = 0.0
    for i, (lo, hi) in enumerate(zip(cuts[:-1], cuts[1:])):
        member = (probs > lo) & (probs <= hi)
        if not member.any():
            continue
        p_mean = float(probs[member].mean())
        y_mean = float(outcomes[member].mean())
        gap = abs(p_mean - y_mean)
        bins.append({"bin": i, "p_mean": p_mean,
                     "y_mean": y_mean, "n": int(member.sum())})
        ece += float(member.mean()) * gap
        gaps.append(gap)

    ici = sum(gaps) / max(len(bins), 1)
    return {"bins": bins, "ece": float(ece), "ici": float(ici)}
|
||
|
||
|
||
def count_ever_after_context_anytime(
    loader: DataLoader,
    offset_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count, per disease, the individuals with that disease at least once after context.

    Context selection uses ``select_context_indices`` with *offset_years*.
    Samples without a valid context are dropped. Any disease token (id >= 2)
    after the context position counts, and repeated occurrences for one person
    count once.

    Returns:
        (counts, n_total_eval): counts[k] is the number of individuals with
        disease k after context, and n_total_eval is the number of individuals
        that had a valid context.
    """
    n_dis = int(n_disease)
    per_disease = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_eval = 0

    for event_seq, time_seq, _cont, _cate, _sexes in loader:
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)

        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        n_eval += int(keep.sum().item())

        events = event_seq[keep]
        ctx = t_ctx[keep]
        n_rows, seq_len = events.shape
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(n_rows, -1)
        future = (positions > ctx.unsqueeze(1)) & (events >= 2)
        if not future.any():
            continue

        rows, cols = future.nonzero(as_tuple=True)
        disease = (events[rows, cols] - 2).to(torch.long)
        # One key per (person, disease) pair, so repeat diagnoses count once.
        pair_keys = torch.unique(rows.to(torch.long) * n_dis + disease)
        hit_disease = pair_keys % n_dis
        per_disease.scatter_add_(0, hit_disease, torch.ones_like(hit_disease, dtype=torch.long))

    return per_disease.detach().cpu().numpy(), int(n_eval)
|
||
|
||
|
||
# ============================================================
|
||
# Evaluation core
|
||
# ============================================================
|
||
|
||
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    """Rebuild the backbone and output head described by a training config.

    The head's output shape depends on ``cfg["loss_type"]``: ``[n_disease]``
    for "exponential", or ``[n_disease + 1, len(bin_edges)]`` for
    "discrete_time_cif". ``bin_edges`` comes from the config, falling back to
    ``DEFAULT_BIN_EDGES``, and is returned for both loss types.

    For "sap_delphi" the pretrained embedding path is read from
    ``pretrained_emb_path``, falling back to the legacy ``pretrained_emd_path``.
    A warning is printed when neither is set.

    Returns:
        (backbone, head, loss_type, bin_edges); both modules are on *device*.

    Raises:
        ValueError: for an unsupported loss_type or model_type.
    """
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)

    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")

    if model_type not in ("delphi_fork", "sap_delphi"):
        raise ValueError(f"Unsupported model_type: {model_type}")

    sap_kwargs: Dict[str, Any] = {}
    if model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in {"", None}:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in {"", None}:
            run_dir = os.path.dirname(os.path.abspath(
                checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        sap_kwargs = {"pretrained_weights_path": emb_path, "freeze_embeddings": True}

    backbone_cls = DelphiFork if model_type == "delphi_fork" else SapDelphi
    backbone = backbone_cls(
        n_disease=dataset.n_disease,
        n_tech_tokens=2,
        n_embd=int(cfg["n_embd"]),
        n_head=int(cfg["n_head"]),
        n_layer=int(cfg["n_layer"]),
        pdrop=float(cfg.get("pdrop", 0.0)),
        age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
        n_cont=dataset.n_cont,
        n_cate=dataset.n_cate,
        cate_dims=dataset.cate_dims,
        **sap_kwargs,
    ).to(device)

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    return backbone, head, loss_type, bin_edges
|
||
|
||
|
||
@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run the model on every sample with a valid context and collect CIFs and labels.

    Context tokens are chosen by ``select_context_indices``. Samples without
    one are dropped, and a RuntimeError is raised if nothing remains.

    Returns (a 9-tuple, in this order):
        allcause_risk: (N, H) sum of CIFs over all causes
        cause_cif: (N, topK, H) CIFs of the selected causes
        cif_full: (N, K, H)
        survival: (N, H)
        sex: (N,)
        y_allcause_tau: (N, H) 1 if the first disease after context is within tau
        y_cause_ever_anytime: (N, topK)
        y_cause_within_tau: (N, topK, H)
        y_cause_within_tau_max: (N, topK) within-horizon labels at max(eval_horizons)

    NOTE:
        - y_cause_ever_anytime is the Delphi2M-compatible case/control label.
        - y_cause_within_tau_* are within-horizon labels (kept for
          legacy/secondary AUC).
    """
    backbone.eval()
    head.eval()

    # Accumulate per-batch results on CPU, then concatenate once.
    allcause_list: List[np.ndarray] = []
    cause_cif_list: List[np.ndarray] = []
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []
    y_all_list: List[np.ndarray] = []
    y_cause_ever_any_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    y_cause_within_tau_max_list: List[np.ndarray] = []

    horizons = [float(t) for t in eval_horizons]
    # The tau_max labels equal the within-horizon labels at the largest
    # horizon, so reuse that column instead of recomputing them.
    h_max = horizons.index(max(horizons))

    top_cause_ids_t = torch.tensor(
        top_cause_ids, dtype=torch.long, device=device)
    # Horizons tensor is built once per model (on device) for vectorized comparisons.
    eval_horizons_t = torch.tensor(
        horizons, device=device, dtype=torch.float32).view(1, -1)

    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue

        # Keep only samples that have a valid context.
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep]
        t_ctx = t_ctx[keep]

        h = backbone(event_seq, time_seq, sexes_k,
                     cont_feats, cate_feats)  # (B,L,D)
        b = torch.arange(h.size(0), device=device)
        logits = head(h[b, t_ctx])  # head applied to the context representation (B,D)

        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                logits, bin_edges, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")

        n_disease = int(cif_full.size(1))
        allcause = cif_full.sum(dim=1)  # (B,H)
        cause_cif = cif_full.index_select(
            dim=1, index=top_cause_ids_t)  # (B,topK,H)

        # All-cause outcome: first disease after context falls within each horizon.
        dt_next, _cause_next = next_event_after_context(
            event_seq, time_seq, t_ctx)
        y_all = (dt_next.view(-1, 1) <= eval_horizons_t).to(torch.float32)

        # Delphi2M-compatible ever label (does not depend on horizon)
        y_ever_any = multi_hot_ever_after_context_anytime(
            event_seq=event_seq,
            t_ctx=t_ctx,
            n_disease=n_disease,
        )
        y_ever_any_top = y_ever_any.index_select(
            dim=1, index=top_cause_ids_t).to(torch.float32)

        # Within-horizon labels for cause-specific CIF quality + legacy AUC
        y_within_top = torch.stack(
            [
                multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=tau,
                    cause_ids=top_cause_ids_t,
                    n_disease=n_disease,
                ).to(torch.float32)
                for tau in horizons
            ],
            dim=2,
        )  # (B,topK,H)
        y_within_tau_max_top = y_within_top[:, :, h_max]  # (B,topK)

        allcause_list.append(allcause.detach().cpu().numpy())
        cause_cif_list.append(cause_cif.detach().cpu().numpy())
        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())
        y_all_list.append(y_all.detach().cpu().numpy())
        y_cause_ever_any_list.append(y_ever_any_top.detach().cpu().numpy())
        y_cause_within_list.append(y_within_top.detach().cpu().numpy())
        y_cause_within_tau_max_list.append(
            y_within_tau_max_top.detach().cpu().numpy())

    if not allcause_list:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")

    return (
        np.concatenate(allcause_list, axis=0),
        np.concatenate(cause_cif_list, axis=0),
        np.concatenate(cif_full_list, axis=0),
        np.concatenate(survival_list, axis=0),
        np.concatenate(sex_list, axis=0),
        np.concatenate(y_all_list, axis=0),
        np.concatenate(y_cause_ever_any_list, axis=0),
        np.concatenate(y_cause_within_list, axis=0),
        np.concatenate(y_cause_within_tau_max_list, axis=0),
    )
|
||
|
||
|
||
def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    """Return the column ids of the (at most) ``top_k`` most frequent causes.

    Args:
        y_ever: label matrix whose columns are causes; columns are summed over
            axis 0 to obtain a per-cause count.
        top_k: maximum number of cause ids to return.

    Returns:
        Cause ids ordered by descending count. Causes whose count is not
        strictly positive are never returned.
    """
    case_counts = y_ever.sum(axis=0)
    descending = np.argsort(-case_counts)
    # Drop causes that never occur before truncating to top_k.
    has_cases = case_counts[descending] > 0
    return descending[has_cases][:top_k]
|
||
|
||
|
||
def _eval_auc_with_ci(
    p: np.ndarray,
    y: np.ndarray,
    auc_ci_method: str,
    bootstrap_n: int,
) -> Tuple[float, float, float]:
    """Return ``(auc, ci_low, ci_high)`` for scores ``p`` against labels ``y``.

    ``auc_ci_method == "none"`` yields three NaNs, ``"bootstrap"`` delegates to
    ``bootstrap_auc_ci`` and any other value falls back to ``delong_ci``.
    """
    if auc_ci_method == "none":
        nan = float("nan")
        return nan, nan, nan
    if auc_ci_method == "bootstrap":
        # bootstrap_auc_ci takes (scores, labels) ...
        return bootstrap_auc_ci(p, y, n_bootstrap=bootstrap_n, alpha=0.95)
    # ... whereas delong_ci takes (labels, scores).
    return delong_ci(y, p, alpha=0.95)


def _eval_metric_row(
    model_name: str,
    metric_name: str,
    horizon: Any,
    cause: Any,
    value: Any,
    ci_low: Any = "",
    ci_high: Any = "",
) -> Dict[str, Any]:
    """Build one row for the main results CSV."""
    return {
        "model_name": model_name,
        "metric_name": metric_name,
        "horizon": horizon,
        "cause": cause,
        "value": value,
        "ci_low": ci_low,
        "ci_high": ci_high,
    }


def _eval_calib_bin_rows(
    model_name: str,
    task: str,
    horizon: float,
    cause_id: int,
    cal: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Convert the ``"bins"`` entries of a calibration result into CSV rows."""
    return [
        {
            "model_name": model_name,
            "task": task,
            "horizon": float(horizon),
            "cause_id": int(cause_id),
            "bin_index": int(binfo["bin"]),
            "p_mean": float(binfo["p_mean"]),
            "y_mean": float(binfo["y_mean"]),
            "n_in_bin": int(binfo["n"]),
        }
        for binfo in cal.get("bins", [])
    ]


def evaluate_one_model(
    model_name: str,
    allcause_risk: np.ndarray,
    cause_cif: np.ndarray,
    sex: np.ndarray,
    y_allcause: np.ndarray,
    y_cause_ever_anytime: np.ndarray,
    y_cause_within_tau: np.ndarray,
    y_cause_within_tau_max: np.ndarray,
    eval_horizons: Sequence[float],
    top_cause_ids: np.ndarray,
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
    cause_calib_horizons: Sequence[float] = (3.84, 10.0),
) -> None:
    """Compute metrics for one model and append them to the output row lists.

    ``out_rows`` and ``calib_rows`` are mutated in place. Nothing is returned.

    Args:
        model_name: value written into every row's ``model_name`` column.
        allcause_risk: all-cause risk, column ``h`` per horizon (``[:, h]``).
        cause_cif: top-cause CIFs indexed ``[:, j, h]``
            (sample, top-cause position, horizon).
        sex: per-sample sex code. Values 0 and 1 each get a stratified
            all-cause AUC when the stratum has at least 10 samples.
        y_allcause: all-cause labels indexed ``[:, h]``.
        y_cause_ever_anytime: ``[:, j]`` labels. The cause occurs anytime
            after the context (Delphi2M-compatible).
        y_cause_within_tau: ``[:, j, h]`` labels within each horizon.
        y_cause_within_tau_max: ``[:, j]`` labels within the largest horizon.
        eval_horizons: horizons (years) matching the last axis of the arrays.
            Must be non-empty and need not be sorted.
        top_cause_ids: cause ids matching the ``j`` axis.
        out_rows: receives metric rows
            (model_name/metric_name/horizon/cause/value/ci_low/ci_high).
        calib_rows: receives calibration-bin rows.
        auc_ci_method: ``"none"``, ``"bootstrap"`` or (otherwise) DeLong.
        bootstrap_n: bootstrap resamples when ``auc_ci_method == "bootstrap"``.
        n_calib_bins: number of calibration bins, used for both all-cause and
            cause-specific calibration.
        cause_calib_horizons: horizons for cause-specific Brier/ECE/ICI.
            Horizons absent from ``eval_horizons`` are skipped.

    Raises:
        ValueError: if ``eval_horizons`` is empty.
    """
    horizons = [float(t) for t in eval_horizons]
    if not horizons:
        raise ValueError("eval_horizons must contain at least one horizon")
    cause_ids = [int(c) for c in np.asarray(top_cause_ids).tolist()]

    # Task B (all-cause): Brier + AUC + calibration per horizon.
    for h_i, tau in enumerate(horizons):
        p = allcause_risk[:, h_i]
        y = y_allcause[:, h_i]

        out_rows.append(_eval_metric_row(
            model_name, "allcause_brier", tau, "", brier_score(p, y)))

        auc, lo, hi = _eval_auc_with_ci(p, y, auc_ci_method, bootstrap_n)
        out_rows.append(_eval_metric_row(
            model_name, "allcause_auc", tau, "", auc, lo, hi))

        cal = calibration_deciles(p, y, n_bins=n_calib_bins)
        out_rows.append(_eval_metric_row(
            model_name, "allcause_ece", tau, "", cal["ece"]))
        out_rows.append(_eval_metric_row(
            model_name, "allcause_ici", tau, "", cal["ici"]))
        # All-cause calibration bins are always written to the bins CSV.
        calib_rows.extend(
            _eval_calib_bin_rows(model_name, "all_cause", tau, -1, cal))

        # Stratification by sex (strata with < 10 samples are skipped).
        for s_val in (0, 1):
            m = sex == s_val
            if np.sum(m) < 10:
                continue
            auc_s, lo_s, hi_s = _eval_auc_with_ci(
                p[m], y[m], auc_ci_method, bootstrap_n)
            out_rows.append(_eval_metric_row(
                model_name, f"allcause_auc_sex{s_val}", tau, "",
                auc_s, lo_s, hi_s))

    # Pick the largest horizon's column explicitly. The original code assumed it
    # was the last column, which mislabels results for unsorted horizons.
    tau_max_idx = int(np.argmax(horizons))
    tau_max = horizons[tau_max_idx]
    p_tau_max = cause_cif[:, :, tau_max_idx]  # (N, topK)

    # Task A (Delphi2M-compatible discrimination): per-cause AUC with EVER
    # labels. Case/control is whether the disease appears ANYTIME after context.
    for j, cause_id in enumerate(cause_ids):
        auc, lo, hi = _eval_auc_with_ci(
            p_tau_max[:, j], y_cause_ever_anytime[:, j],
            auc_ci_method, bootstrap_n)
        out_rows.append(_eval_metric_row(
            model_name, "cause_auc_ever", tau_max, cause_id, auc, lo, hi))

    # Keep the tau-window AUC as a separate (secondary) metric.
    for j, cause_id in enumerate(cause_ids):
        auc, lo, hi = _eval_auc_with_ci(
            p_tau_max[:, j], y_cause_within_tau_max[:, j],
            auc_ci_method, bootstrap_n)
        out_rows.append(_eval_metric_row(
            model_name, "cause_auc", tau_max, cause_id, auc, lo, hi))

    # Task B additions: cause-specific Brier + calibration at selected horizons.
    # On duplicate horizons, the later index wins (same as the original dict).
    horizon_to_idx = {t: i for i, t in enumerate(horizons)}
    for tau_raw in cause_calib_horizons:
        tau = float(tau_raw)
        h_idx = horizon_to_idx.get(tau)
        if h_idx is None:
            continue
        p_tau = cause_cif[:, :, h_idx]  # (N, topK)
        y_tau = y_cause_within_tau[:, :, h_idx]  # (N, topK)

        for j, cause_id in enumerate(cause_ids):
            p = p_tau[:, j]
            y = y_tau[:, j]

            out_rows.append(_eval_metric_row(
                model_name, "cause_brier", tau, cause_id, brier_score(p, y)))

            # Use the same bin count as the all-cause calibration. The
            # original silently ignored n_calib_bins here.
            cal = calibration_deciles(p, y, n_bins=n_calib_bins)
            out_rows.append(_eval_metric_row(
                model_name, "cause_ece", tau, cause_id, cal["ece"]))
            out_rows.append(_eval_metric_row(
                model_name, "cause_ici", tau, cause_id, cal["ici"]))
            calib_rows.extend(
                _eval_calib_bin_rows(model_name, "cause_k", tau, cause_id, cal))
|
||
|
||
|
||
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write calibration-curve bin rows to ``path`` as a UTF-8 CSV.

    Any existing file is overwritten. Each row may only use the header keys
    below: ``csv.DictWriter`` raises ``ValueError`` on any other key and writes
    missing keys as empty cells. UTF-8 is set explicitly so that model names
    outside the platform's default code page are written correctly on every OS.
    """
    fieldnames = [
        "model_name",
        "task",
        "horizon",
        "cause_id",
        "bin_index",
        "p_mean",
        "y_mean",
        "n_in_bin",
    ]
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
|
||
|
||
|
||
def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write metric rows to ``path`` as a UTF-8 CSV with a fixed header.

    Any existing file is overwritten. Each row may only use the header keys
    below: ``csv.DictWriter`` raises ``ValueError`` on any other key and writes
    missing keys as empty cells. UTF-8 is set explicitly so that model names
    outside the platform's default code page are written correctly on every OS.
    """
    fieldnames = [
        "model_name",
        "metric_name",
        "horizon",
        "cause",
        "value",
        "ci_low",
        "ci_high",
    ]
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
|
||
|
||
|
||
def _make_eval_tag(split: str, offset_years: float) -> str:
|
||
"""Short tag for filenames written into run directories."""
|
||
off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".")
|
||
return f"{split}_offset{off}y"
|
||
|
||
|
||
def _eval_covariate_list(cfg: Dict[str, Any]) -> Optional[List[str]]:
    """Return the ``covariate_list`` to pass to ``HealthDataset`` for ``cfg``.

    Returns ``None`` when the training config sets ``full_cov``, otherwise the
    reduced ``["bmi", "smoking", "alcohol"]`` list. ``None`` presumably means
    "all covariates"; confirm against ``HealthDataset``.
    """
    if _parse_bool(cfg.get("full_cov", False)):
        return None
    return ["bmi", "smoking", "alcohol"]


def _eval_build_loader(
    cfg: Dict[str, Any],
    data_prefix: str,
    split: str,
    batch_size: int,
    num_workers: int,
) -> Tuple[Any, DataLoader]:
    """Build ``(dataset, loader)`` for ``split``, where the loader is unshuffled.

    ``train_ratio``, ``val_ratio`` and ``random_seed`` are read from ``cfg``
    (the checkpoint's training config), falling back to 0.7 / 0.15 / 42.
    """
    dataset = HealthDataset(
        data_prefix=data_prefix, covariate_list=_eval_covariate_list(cfg))
    subset = build_eval_subset(
        dataset,
        train_ratio=float(cfg.get("train_ratio", 0.7)),
        val_ratio=float(cfg.get("val_ratio", 0.15)),
        seed=int(cfg.get("random_seed", 42)),
        split=split,
    )
    loader = DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=health_collate_fn,
    )
    return dataset, loader


def _eval_topcause_count_rows(
    model_name: str, top_causes_meta: Sequence[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Expand each top cause's case/control/total counts into results-CSV rows."""
    metric_keys = (
        ("topcause_n_case_ever", "n_case_ever"),
        ("topcause_n_control_ever", "n_control_ever"),
        ("topcause_n_total_eval", "n_total_eval"),
    )
    return [
        {
            "model_name": model_name,
            "metric_name": metric_name,
            "horizon": "",
            "cause": int(tc["cause_id"]),
            "value": int(tc[key]),
            "ci_low": "",
            "ci_high": "",
        }
        for tc in top_causes_meta
        for metric_name, key in metric_keys
    ]


def main() -> int:
    """Command-line entry point: evaluate every model listed in ``--models_json``.

    Steps:
      1. Choose the top-K causes once, using the evaluation split built from
         the first model's training config. Causes are ranked by their
         "ever after context" counts from ``count_ever_after_context_anytime``.
      2. For each model, rebuild its evaluation loader, load the checkpoint
         weights, predict CIFs, run CIF integrity checks and compute metrics.
      3. Write a results CSV, a calibration-bins CSV and a meta JSON into each
         checkpoint's directory. Also write combined outputs to ``--out_csv``,
         the calibration CSV (``--out_calib_csv``, or ``calibration_bins.csv``
         next to ``--out_csv``) and ``--out_meta_json``.

    Returns:
        Process exit code (0 on success). Invalid CLI values exit through
        ``argparse`` with status 2.
    """
    ap = argparse.ArgumentParser(
        description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True,
                    help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str,
                    default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"],
                    help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5,
                    help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float,
                    nargs="*", default=DEFAULT_EVAL_HORIZONS)
    ap.add_argument("--top_k_causes", type=int, default=50)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--out_csv", type=str, default="eval_results.csv")
    ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")
    ap.add_argument("--out_calib_csv", type=str, default=None,
                    help="Combined calibration-bins CSV path "
                         "(default: calibration_bins.csv next to --out_csv)")

    # Integrity checks
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)

    # AUC CI methods
    ap.add_argument(
        "--auc_ci_method",
        type=str,
        default="delong",
        choices=["delong", "bootstrap", "none"],
    )
    ap.add_argument("--bootstrap_n", type=int, default=2000)
    args = ap.parse_args()

    # Fail fast on values that would otherwise crash deep inside evaluation
    # (max() over an empty horizon list) or silently mis-slice the top-K list.
    if not args.eval_horizons:
        ap.error("--eval_horizons needs at least one value")
    if args.top_k_causes <= 0:
        ap.error("--top_k_causes must be a positive integer")
    if args.auc_ci_method == "bootstrap" and args.bootstrap_n <= 0:
        ap.error("--bootstrap_n must be positive with --auc_ci_method bootstrap")

    set_deterministic(args.seed)

    specs = load_models_json(args.models_json)
    if not specs:
        raise ValueError("No models provided")

    # Determine top-K causes from the evaluation split only (model-agnostic).
    first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
    dataset_for_top, loader_top = _eval_build_loader(
        first_cfg, args.data_prefix, args.split,
        args.batch_size, args.num_workers)
    counts, n_total_eval = count_ever_after_context_anytime(
        loader=loader_top,
        offset_years=args.offset_years,
        n_disease=dataset_for_top.n_disease,
        device=args.device,
    )
    # The counting dataset is not needed again; free it before loading models.
    del loader_top, dataset_for_top

    order = np.argsort(-counts)
    order = order[counts[order] > 0]
    top_cause_ids = order[: args.top_k_causes]

    # Record top-cause counts under the Delphi2M-compatible EVER label.
    top_causes_meta: List[Dict[str, Any]] = []
    for k in top_cause_ids.tolist():
        n_case = int(counts[int(k)])
        top_causes_meta.append(
            {
                "cause_id": int(k),
                "n_case_ever": n_case,
                "n_control_ever": int(n_total_eval - n_case),
                "n_total_eval": int(n_total_eval),
            }
        )

    rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []
    # Per-model integrity status for the meta JSON.
    integrity_meta: Dict[str, Any] = {}

    tag = _make_eval_tag(args.split, float(args.offset_years))

    for spec in specs:
        run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))

        # Remember list offsets so this model's slice can also be written to
        # its own run_dir.
        rows_start = len(rows)
        calib_start = len(calib_rows)

        cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
        dataset, loader = _eval_build_loader(
            cfg, args.data_prefix, args.split,
            args.batch_size, args.num_workers)

        backbone, head, loss_type, bin_edges = instantiate_model_and_head(
            cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
        # NOTE(security): torch.load unpickles arbitrary objects; only evaluate
        # checkpoints that come from a trusted source.
        ckpt = torch.load(spec.checkpoint_path, map_location=args.device)
        backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
        head.load_state_dict(ckpt["head_state_dict"], strict=True)

        (
            allcause_risk,
            cause_cif,
            cif_full,
            survival,
            sex,
            y_allcause,
            y_cause_ever_anytime,
            y_cause_within_tau,
            y_cause_within_tau_max,
        ) = predict_cifs_for_model(
            backbone,
            head,
            loss_type,
            bin_edges,
            loader,
            args.device,
            args.offset_years,
            args.eval_horizons,
            top_cause_ids,
        )

        # CIF integrity checks before metrics.
        integrity_ok, integrity_notes = check_cif_integrity(
            cif_full,
            args.eval_horizons,
            tol=float(args.integrity_tol),
            name=spec.name,
            strict=bool(args.integrity_strict),
            survival=survival,
        )
        integrity_meta[spec.name] = {
            "integrity_ok": bool(integrity_ok),
            "integrity_notes": integrity_notes,
        }

        evaluate_one_model(
            model_name=spec.name,
            allcause_risk=allcause_risk,
            cause_cif=cause_cif,
            sex=sex,
            y_allcause=y_allcause,
            y_cause_ever_anytime=y_cause_ever_anytime,
            y_cause_within_tau=y_cause_within_tau,
            y_cause_within_tau_max=y_cause_within_tau_max,
            eval_horizons=args.eval_horizons,
            top_cause_ids=top_cause_ids,
            out_rows=rows,
            calib_rows=calib_rows,
            auc_ci_method=str(args.auc_ci_method),
            bootstrap_n=int(args.bootstrap_n),
        )

        # Also record top-cause counts in the results CSV as metric rows.
        rows.extend(_eval_topcause_count_rows(spec.name, top_causes_meta))

        # Write per-model results into the model's run directory.
        model_rows = rows[rows_start:]
        model_calib_rows = calib_rows[calib_start:]
        model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
        model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")

        write_results_csv(model_out_csv, model_rows)
        write_calibration_bins_csv(model_calib_csv, model_calib_rows)

        model_meta = {
            "model_name": spec.name,
            "checkpoint_path": spec.checkpoint_path,
            "run_dir": run_dir,
            "split": args.split,
            "offset_years": args.offset_years,
            "eval_horizons": [float(x) for x in args.eval_horizons],
            "top_k_causes": int(args.top_k_causes),
            "top_cause_ids": top_cause_ids.tolist(),
            "top_causes": top_causes_meta,
            "integrity": {spec.name: integrity_meta.get(spec.name, {})},
            "paths": {
                "results_csv": model_out_csv,
                "calibration_bins_csv": model_calib_csv,
            },
        }
        with open(model_meta_json, "w", encoding="utf-8") as f:
            json.dump(model_meta, f, indent=2)

        print(f"Wrote per-model results to {model_out_csv}")

    write_results_csv(args.out_csv, rows)

    # Write calibration curve points to a separate CSV.
    if args.out_calib_csv:
        calib_csv_path = args.out_calib_csv
    else:
        out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
        calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
    write_calibration_bins_csv(calib_csv_path, calib_rows)

    meta = {
        "split": args.split,
        "offset_years": args.offset_years,
        "eval_horizons": [float(x) for x in args.eval_horizons],
        "top_k_causes": int(args.top_k_causes),
        "top_cause_ids": top_cause_ids.tolist(),
        "top_causes": top_causes_meta,
        "integrity": integrity_meta,
        "paths": {
            "results_csv": args.out_csv,
            "calibration_bins_csv": calib_csv_path,
        },
        "notes": {
            "task_a_label": "Delphi2M-compatible: disease occurs ANYTIME after context (ever in remaining sequence)",
            "task_a_legacy_label": "Secondary: disease occurs within tau_max after context",
            "task_b_label": "all-cause event within horizon (equivalent to next disease event within horizon)",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
        },
    }
    with open(args.out_meta_json, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    print(f"Wrote {args.out_csv} with {len(rows)} rows")
    print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
    print(f"Wrote {args.out_meta_json}")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)
|