2413 lines
85 KiB
Python
2413 lines
85 KiB
Python
import argparse
|
||
import csv
|
||
import json
|
||
import math
|
||
import os
|
||
import random
|
||
import sys
|
||
import time
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from torch.utils.data import DataLoader, random_split
|
||
|
||
from dataset import HealthDataset, health_collate_fn
|
||
from model import DelphiFork, SapDelphi, SimpleHead
|
||
|
||
|
||
# ============================================================
|
||
# Constants / defaults (aligned with evaluate_prompt.md)
|
||
# ============================================================
|
||
# Time-bin boundaries; the last edge is +inf so the final bin is open-ended.
# presumably expressed in years, like the evaluation horizons — TODO confirm
DEFAULT_BIN_EDGES = [0.0, 0.24, 0.72, 1.61, 3.84, 10.0, 31.0, math.inf]

# Horizons at which CIFs are evaluated by default.
DEFAULT_EVAL_HORIZONS = [0.25, 0.5, 1.0, 2.0, 5.0, 10.0]

# Conversion factor between the day-based token times and year-based horizons.
DAYS_PER_YEAR = 365.25

# 0-based cause id used as the "death" focus cause by default.
DEFAULT_DEATH_CAUSE_ID = 1256
|
||
|
||
|
||
# ============================================================
|
||
# Model specs
|
||
# ============================================================
|
||
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one model entry to evaluate."""

    # Identifier of the model entry.
    name: str
    # Architecture key: delphi_fork | sap_delphi
    model_type: str
    # Loss key: exponential | discrete_time_cif | lognormal_basis_binned_hazard_cif
    loss_type: str
    # NOTE(review): presumably selects the full-covariate model variant — confirm.
    full_cov: bool
    # Path to the checkpoint file for this model.
    checkpoint_path: str
|
||
|
||
|
||
# ============================================================
|
||
# Determinism
|
||
# ============================================================
|
||
|
||
def set_deterministic(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: value passed unchanged to every seeding function.
    """
    # Same seeding order as before: stdlib, NumPy, torch CPU, torch CUDA (all devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
||
|
||
|
||
# ============================================================
|
||
# Utilities
|
||
# ============================================================
|
||
|
||
def _parse_bool(x: Any) -> bool:
    """Interpret *x* as a boolean.

    Real ``bool`` values are returned as-is. Anything else is converted with
    ``str()``, stripped and lower-cased, then matched against a fixed set of
    true/false spellings.

    Raises:
        ValueError: if the text is not one of the recognised spellings.
    """
    if isinstance(x, bool):
        return x

    token = str(x).strip().lower()
    truthy = ("true", "1", "yes", "y")
    falsy = ("false", "0", "no", "n")
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise ValueError(f"Cannot parse boolean: {x!r}")
|
||
|
||
|
||
def load_models_json(path: str) -> List[ModelSpec]:
    """Load the list of model specifications from a JSON file.

    The file must contain a JSON list; each entry must provide the keys
    ``name``, ``model_type``, ``loss_type``, ``full_cov`` and
    ``checkpoint_path``.

    Args:
        path: location of the JSON file.

    Returns:
        One ``ModelSpec`` per entry, in file order.

    Raises:
        ValueError: if the top-level JSON value is not a list, or ``full_cov``
            cannot be parsed as a boolean.
        KeyError: if an entry is missing one of the required keys.
    """
    # Read explicitly as UTF-8 (the JSON interchange encoding) instead of the
    # platform default, consistent with the other file readers in this module.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("models_json must be a list of model entries")

    return [
        ModelSpec(
            name=str(row["name"]),
            model_type=str(row["model_type"]),
            loss_type=str(row["loss_type"]),
            full_cov=_parse_bool(row["full_cov"]),
            checkpoint_path=str(row["checkpoint_path"]),
        )
        for row in data
    ]
|
||
|
||
|
||
def load_train_config_for_checkpoint(checkpoint_path: str) -> Dict[str, Any]:
    """Load ``train_config.json`` from the directory that holds a checkpoint.

    Args:
        checkpoint_path: path to a checkpoint file; only its parent directory
            is used (the checkpoint itself is not opened).

    Returns:
        The parsed JSON content of ``<checkpoint dir>/train_config.json``.

    Raises:
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    cfg_path = os.path.join(run_dir, "train_config.json")
    # Explicit UTF-8: JSON is UTF-8 by specification, and relying on the
    # locale default can mis-decode non-ASCII config values on some platforms.
    with open(cfg_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||
|
||
|
||
def build_eval_subset(
    dataset: HealthDataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    """Return the requested split of *dataset*.

    The train/val/test partition is produced by ``random_split`` with a
    generator seeded by *seed*, using sizes ``int(n * train_ratio)``,
    ``int(n * val_ratio)`` and the remainder for test.

    Args:
        dataset: full dataset to partition.
        train_ratio: fraction of samples assigned to the train split.
        val_ratio: fraction of samples assigned to the validation split.
        seed: seed for the split generator.
        split: one of ``train``, ``val``, ``test``, ``all``.

    Returns:
        The chosen subset, or *dataset* itself when ``split == "all"``.

    Raises:
        ValueError: if *split* is not one of the accepted names.
    """
    # Validate first so a typo fails fast instead of after the split work.
    if split not in ("train", "val", "test", "all"):
        raise ValueError("split must be one of: train, val, test, all")
    # "all" never needs the partition, so skip computing it.
    if split == "all":
        return dataset

    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val

    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    by_name = {"train": train_ds, "val": val_ds, "test": test_ds}
    return by_name[split]
|
||
|
||
|
||
# ============================================================
|
||
# Context selection (anti-leakage)
|
||
# ============================================================
|
||
|
||
def select_context_indices(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    offset_years: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Choose, per sample, the token used as prediction context.

    Semantics:
    - The time of the last non-padding token is treated as the end of
      follow-up, not as an event time.
    - Tokens are eligible when they are non-padding and their time is
      <= (follow-up end - offset_years, converted to days).
    - The context index is (number of eligible tokens - 1).
      NOTE(review): this equals the last eligible position only if
      non-padding tokens form a time-sorted prefix of each row — confirm.

    Returns:
        keep_mask: (B,) bool, True where at least one token is eligible
        t_ctx: (B,) long, context index (0 when nothing is eligible)
        t_ctx_time: (B,) time at the context index
    """
    rows = torch.arange(event_seq.size(0), device=event_seq.device)

    # Padding tokens carry id 0.
    is_token = event_seq != 0
    n_tokens = is_token.sum(dim=1)
    end_idx = torch.clamp(n_tokens - 1, min=0)

    end_time = time_seq[rows, end_idx]
    cutoff = end_time - (offset_years * DAYS_PER_YEAR)

    usable = is_token & (time_seq <= cutoff.unsqueeze(1))
    n_usable = usable.sum(dim=1)

    keep = n_usable > 0
    t_ctx = torch.clamp(n_usable - 1, min=0).to(torch.long)
    t_ctx_time = time_seq[rows, t_ctx]
    return keep, t_ctx, t_ctx_time
|
||
|
||
|
||
def next_event_after_context(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the first disease token positioned after the context index.

    A disease token is one with id >= 2; its cause id is ``token - 2``.
    Selection is by sequence position, so same-day events after the context
    token are included while the context token itself is excluded.

    Returns:
        dt_years: (B,) time from context to that token in years; +inf if none
        cause: (B,) long, ``token - 2`` of that token; -1 if none
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(n_rows, device=dev)
    ctx_time = time_seq[rows, t_ctx]

    pos = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    after_ctx = pos > t_ctx.unsqueeze(1)
    is_disease = (event_seq >= 2) & (event_seq != 0)
    candidate = after_ctx & is_disease

    # Non-candidates get the sentinel position seq_len so min() picks the
    # earliest candidate, or seq_len when the row has none.
    sentinel = torch.full_like(pos, seq_len)
    first_pos = torch.where(candidate, pos, sentinel).min(dim=1).values

    found = first_pos < seq_len
    # Use position 0 as a harmless gather index for rows without a hit.
    safe_pos = torch.where(found, first_pos, torch.zeros_like(first_pos))

    elapsed_days = time_seq[rows, safe_pos] - ctx_time
    elapsed_years = elapsed_days / DAYS_PER_YEAR
    dt_years = torch.where(
        found, elapsed_years, torch.full_like(elapsed_years, float("inf")))

    cause = (event_seq[rows, safe_pos] - 2).to(torch.long)
    cause = torch.where(found, cause, torch.full_like(cause, -1))
    return dt_years, cause
|
||
|
||
|
||
def multi_hot_ever_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    n_disease: int,
) -> torch.Tensor:
    """Multi-hot labels: does disease k occur within tau after the context?

    A token counts when it sits after the context index, has id >= 2, and its
    time lies in [t_ctx_time, t_ctx_time + tau_years * DAYS_PER_YEAR].

    Returns:
        (B, n_disease) bool tensor indexed by ``token - 2``.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device
    rows = torch.arange(n_rows, device=dev)

    start = time_seq[rows, t_ctx]
    stop = start + (tau_years * DAYS_PER_YEAR)

    pos = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    # Same-day events after the context are kept; anything at or before the
    # context position is dropped.
    after_ctx = pos > t_ctx.unsqueeze(1)
    in_time = (time_seq >= start.unsqueeze(1)) & (time_seq <= stop.unsqueeze(1))
    is_disease = (event_seq >= 2) & (event_seq != 0)
    hits = after_ctx & in_time & is_disease

    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=dev)
    if hits.any():
        hit_rows, hit_cols = hits.nonzero(as_tuple=True)
        hit_ids = (event_seq[hit_rows, hit_cols] - 2).to(torch.long)
        labels[hit_rows, hit_ids] = True
    return labels
|
||
|
||
|
||
def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Multi-hot labels: does disease k occur at ANY position after the context?

    Delphi2M-compatible case/control definition for Task A. Only sequence
    position matters, so same-day events placed after the context token are
    included. Disease tokens are those with id >= 2 and map to ``token - 2``.

    Returns:
        (B, n_disease) bool tensor.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device

    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=dev)

    pos = torch.arange(seq_len, device=dev).unsqueeze(0).expand(n_rows, -1)
    after_ctx = pos > t_ctx.unsqueeze(1)
    is_disease = (event_seq >= 2) & (event_seq != 0)
    hits = after_ctx & is_disease

    if hits.any():
        hit_rows, hit_cols = hits.nonzero(as_tuple=True)
        hit_ids = (event_seq[hit_rows, hit_cols] - 2).to(torch.long)
        labels[hit_rows, hit_ids] = True
    return labels
|
||
|
||
|
||
def multi_hot_selected_causes_within_horizon(
    event_seq: torch.Tensor,
    time_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    tau_years: float,
    cause_ids: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Within-horizon labels restricted to a chosen set of causes.

    Column j of the result answers: does cause ``cause_ids[j]`` occur after
    the context index with time in
    [t_ctx_time, t_ctx_time + tau_years * DAYS_PER_YEAR]?

    Returns:
        (B, cause_ids.numel()) bool tensor.
    """
    n_rows, seq_len = event_seq.shape
    device = event_seq.device
    rows = torch.arange(n_rows, device=device)

    start = time_seq[rows, t_ctx]
    stop = start + (tau_years * DAYS_PER_YEAR)

    pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(n_rows, -1)
    after_ctx = pos > t_ctx.unsqueeze(1)
    in_time = (time_seq >= start.unsqueeze(1)) & (time_seq <= stop.unsqueeze(1))
    is_disease = (event_seq >= 2) & (event_seq != 0)
    hits = after_ctx & in_time & is_disease

    out = torch.zeros((n_rows, cause_ids.numel()), dtype=torch.bool, device=device)
    if not hits.any():
        return out

    hit_rows, hit_cols = hits.nonzero(as_tuple=True)
    hit_ids = (event_seq[hit_rows, hit_cols] - 2).to(torch.long)

    # Global-id -> local-column table; -1 marks causes that were not selected,
    # so the same table doubles as the membership test.
    lookup = torch.full((int(n_disease),), -1, dtype=torch.long, device=device)
    lookup[cause_ids] = torch.arange(cause_ids.numel(), device=device)

    local = lookup[hit_ids]
    wanted = local >= 0
    if not wanted.any():
        return out

    out[hit_rows[wanted], local[wanted]] = True
    return out
|
||
|
||
|
||
# ============================================================
|
||
# CIF conversion
|
||
# ============================================================
|
||
|
||
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential cause-specific logits into CIFs at the given horizons.

    Hazards are ``softplus(logits) + eps``. With total hazard Lambda, the CIF
    of cause k at horizon tau is ``(hazard_k / Lambda) * (1 - exp(-Lambda * tau))``.

    Args:
        logits: (B, K) cause-specific logits.
        taus: horizons, in the time unit the hazards are expressed in.
        eps: positive floor added to every hazard.
        return_survival: also return overall survival ``exp(-Lambda * tau)``.

    Returns:
        cif of shape (B, K, H), or the tuple ``(cif, survival)`` with survival
        of shape (B, H) when ``return_survival`` is True.
    """
    hazards = F.softplus(logits) + eps
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)

    taus_t = torch.tensor(list(taus), device=logits.device,
                          dtype=hazards.dtype).view(1, 1, -1)
    total_h = total.unsqueeze(-1)  # (B,1,1)
    cum_hazard = total_h * taus_t  # (B,1,H)

    # 1 - exp(-x) via expm1: the naive form cancels to exactly 0 in float32
    # once x drops below ~6e-8, which would zero out rare-event CIFs.
    one_minus_surv = -torch.expm1(-cum_hazard)
    frac = hazards / torch.clamp(total, min=eps)
    cif = frac.unsqueeze(-1) * one_minus_surv  # (B,K,H)
    # Guard for a zero total hazard (only reachable when eps is 0).
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))

    if not return_survival:
        return cif

    survival = torch.exp(-cum_hazard).squeeze(1)  # (B,H)
    # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
    nonzero = (total.squeeze(1) > 0).unsqueeze(1)
    survival = torch.where(nonzero, survival, torch.ones_like(survival))
    return cif, survival
|
||
|
||
|
||
def cifs_from_discrete_time_logits(
    logits: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert discrete-time CIF logits into CIFs at the given horizons.

    Args:
        logits: (B, K+1, len(bin_edges)); a softmax over dim 1 gives, per time
            slot, K cause probabilities plus one "no event" probability in the
            last channel.
        bin_edges: edges including the leading 0 and, typically, a trailing +inf.
        taus: horizons; each must lie within 1e-6 of a finite edge in
            ``bin_edges[1:]``.
        return_survival: also return survival at each horizon.

    Returns:
        cif of shape (B, K, H), or ``(cif, survival)`` with survival (B, H).

    Raises:
        ValueError: on a non-3D input, a last-dim/edge-count mismatch, a
            non-finite tau, or a tau that matches no finite edge.
    """
    if logits.ndim != 3:
        raise ValueError("Expected logits shape (B, K+1, n_bins+1)")

    batch, n_channels, n_slots = logits.shape
    n_causes = n_channels - 1

    all_edges = [float(e) for e in bin_edges]
    # Right-hand edges of the finite bins (drops the leading 0 and any +inf).
    finite_edges = [e for e in all_edges[1:] if math.isfinite(e)]
    n_finite = len(finite_edges)

    if n_slots != len(all_edges):
        raise ValueError("logits last dim must match len(bin_edges)")

    probs = torch.softmax(logits, dim=1)  # (B, K+1, n_slots)

    # Only slots 1..n_finite are used; slot 0 and the open-ended slot are ignored.
    window = slice(1, 1 + n_finite)
    hazards = probs[:, :n_causes, window]  # (B, K, n_finite)
    p_comp = probs[:, n_causes, window]  # (B, n_finite)

    # surv_after[u] = prod_{v <= u} p_comp[v]; surv_before shifts it right by one bin.
    surv_after = torch.cumprod(p_comp, dim=1)
    lead = torch.ones((batch, 1), device=logits.device, dtype=probs.dtype)
    surv_before = torch.cat([lead, surv_after[:, :-1]], dim=1)

    cif_bins = torch.cumsum(surv_before.unsqueeze(1) * hazards, dim=2)

    # Nearest-edge lookup with a tolerance, so taus that went through
    # parsing/serialization still match their edge.
    edge_arr = np.asarray(finite_edges, dtype=float)
    picks: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        gaps = np.abs(edge_arr - tau_f)
        nearest = int(np.argmin(gaps))
        if gaps[nearest] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={gaps[nearest]})"
            )
        picks.append(nearest)

    idx = torch.tensor(picks, device=logits.device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B, K, H)

    if not return_survival:
        return cif

    # Survival at horizon h is the running product through bin idx[h].
    survival = surv_after.index_select(dim=1, index=idx)  # (B, H)
    return cif, survival
|
||
|
||
|
||
def _normal_cdf_stable(z: torch.Tensor) -> torch.Tensor:
    """Standard normal CDF of *z*, with z clamped to [-12, 12] first."""
    bounded = torch.clamp(z, -12.0, 12.0)
    erf_term = torch.erf(bounded / math.sqrt(2.0))
    return 0.5 * (1.0 + erf_term)
|
||
|
||
|
||
def cifs_from_lognormal_basis_binned_hazard_logits(
    logits: torch.Tensor,
    *,
    centers: Sequence[float],
    sigma: torch.Tensor,
    bin_edges: Sequence[float],
    taus: Sequence[float],
    eps: float = 1e-8,
    alpha_floor: float = 0.0,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert Route-3 binned hazard logits into CIFs at the given horizons.

    Args:
        logits: (B, J, R), (B, J*R) or (B, 1+J*R); in the last layout the
            leading column is ignored. R is ``len(centers)``.
        centers: centres of the R basis functions in log-time.
        sigma: scalar tensor, shared width of the basis functions.
        bin_edges: edges including the leading 0; only finite bins are used.
        taus: horizons; each must lie within 1e-6 of a finite bin edge.
        eps: floor applied to per-bin hazards before exponentiation.
        alpha_floor: constant added to the softplus-activated coefficients.
        return_survival: also return survival at each horizon.

    Returns:
        cif of shape (B, J, H), or ``(cif, survival)`` with survival (B, H).

    Raises:
        ValueError: on inconsistent shapes, empty centers, no finite edge,
            a non-finite tau, or a tau that matches no finite edge.
    """
    if logits.ndim not in {2, 3}:
        raise ValueError("logits must be 2D or 3D")
    if sigma.ndim != 0:
        raise ValueError("sigma must be a scalar tensor")

    device = logits.device
    dtype = logits.dtype

    centers_t = torch.tensor([float(c) for c in centers], device=device, dtype=dtype)
    n_basis = int(centers_t.numel())
    if n_basis <= 0:
        raise ValueError("centers must be non-empty")

    # Work out J and whether a leading column has to be skipped (2D layouts).
    skip = 0
    if logits.ndim == 3:
        n_causes = int(logits.shape[1])
        if int(logits.shape[2]) != n_basis:
            raise ValueError(
                f"logits.shape[2] must equal R={n_basis}; got {int(logits.shape[2])}"
            )
    else:
        width = int(logits.shape[1])
        if width % n_basis == 0:
            flat = width
        elif (width - 1) % n_basis == 0:
            skip = 1
            flat = width - 1
        else:
            raise ValueError(
                f"logits.shape[1] must be divisible by R={n_basis} (or 1+J*R); got {width}")
        n_causes = flat // n_basis

    if n_causes <= 0:
        raise ValueError("Inferred J must be >= 1")

    edges = [float(e) for e in bin_edges]
    finite_edges = [e for e in edges[1:] if math.isfinite(e)]
    n_bins = len(finite_edges)
    if n_bins <= 0:
        raise ValueError("bin_edges must contain at least one finite edge")

    # Finite bins [edges[k-1], edges[k]) for k = 1..n_bins.
    left = torch.tensor(edges[:n_bins], device=device, dtype=dtype)
    right = torch.tensor(edges[1:1 + n_bins], device=device, dtype=dtype)

    # Lower clamp for log(t), per the original note aligned with the training loss rule.
    t_min = 1e-12
    if len(edges) >= 2 and math.isfinite(edges[1]) and edges[1] > 0:
        t_min = edges[1] * 1e-6
    t_min_t = torch.tensor(float(t_min), device=device, dtype=dtype)

    left_is_zero = left <= 0
    log_left = torch.log(torch.clamp(left, min=t_min_t))
    log_right = torch.log(torch.clamp(right, min=t_min_t))

    sigma_c = sigma.to(device=device, dtype=dtype)
    centers_row = centers_t.unsqueeze(0)
    z_left = (log_left.unsqueeze(-1) - centers_row) / sigma_c
    z_right = (log_right.unsqueeze(-1) - centers_row) / sigma_c

    # A left edge at (or below) 0 contributes CDF exactly 0; the where() is a
    # no-op for rows whose left edge is positive.
    cdf_left = _normal_cdf_stable(z_left)
    cdf_left = torch.where(left_is_zero.unsqueeze(-1),
                           torch.zeros_like(cdf_left), cdf_left)
    cdf_right = _normal_cdf_stable(z_right)
    delta_basis = torch.clamp(cdf_right - cdf_left, min=0.0)  # (n_bins, R)

    raw = logits if logits.ndim == 3 else logits[:, skip:]
    alpha = F.softplus(raw) + float(alpha_floor)
    if logits.ndim != 3:
        alpha = alpha.view(logits.size(0), n_causes, n_basis)  # (B,J,R)

    cause_hazard = torch.einsum("bjr,kr->bjk", alpha, delta_basis)  # (B,J,n_bins)
    total_hazard = cause_hazard.sum(dim=1)  # (B,n_bins)

    # Total is summed from the unclamped per-cause values, then both are floored.
    total_hazard = torch.clamp(total_hazard, min=eps)
    cause_hazard = torch.clamp(cause_hazard, min=eps)

    p_comp = torch.exp(-total_hazard)  # (B,n_bins) per-bin survival
    p_any = -torch.expm1(-total_hazard)  # (B,n_bins) = 1 - exp(-H)
    share = cause_hazard / torch.clamp(total_hazard.unsqueeze(1), min=eps)
    p_event = p_any.unsqueeze(1) * share  # (B,J,n_bins)

    lead = torch.ones((alpha.size(0), 1), device=device, dtype=dtype)
    surv_after = torch.cumprod(p_comp, dim=1)  # survival after each bin
    surv_before = torch.cat([lead, surv_after[:, :-1]], dim=1)  # survival before each bin

    cif_bins = torch.cumsum(surv_before.unsqueeze(1) * p_event, dim=2)  # (B,J,n_bins)

    edge_arr = np.asarray(finite_edges, dtype=float)
    picks: List[int] = []
    for tau in taus:
        tau_f = float(tau)
        if not math.isfinite(tau_f):
            raise ValueError("taus must be finite for discrete-time CIF")
        gaps = np.abs(edge_arr - tau_f)
        nearest = int(np.argmin(gaps))
        if gaps[nearest] > 1e-6:
            raise ValueError(
                f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={gaps[nearest]})"
            )
        picks.append(nearest)

    idx = torch.tensor(picks, device=device, dtype=torch.long)
    cif = cif_bins.index_select(dim=2, index=idx)  # (B,J,H)

    if not return_survival:
        return cif

    survival = surv_after.index_select(dim=1, index=idx)  # (B,H)
    return cif, survival
|
||
|
||
|
||
# ============================================================
|
||
# CIF integrity checks
|
||
# ============================================================
|
||
|
||
def check_cif_integrity(
    cause_cif: np.ndarray,
    horizons: Sequence[float],
    *,
    tol: float = 1e-6,
    name: str = "",
    strict: bool = False,
    survival: Optional[np.ndarray] = None,
) -> Tuple[bool, List[str]]:
    """Run basic sanity checks on CIF arrays.

    Checks: array rank, horizon count, finiteness, range [0, 1],
    monotonicity along horizons, total probability mass <= 1, and (when
    *survival* is given) conservation ``sum_k CIF == 1 - survival``.
    Zero-size arrays pass the numeric checks vacuously instead of raising.

    Args:
        cause_cif: (N, K, H)
        horizons: length H
        tol: tolerance for inequalities
        name: model name for messages
        strict: if True, raise ValueError on first failure
        survival: optional (N, H) survival values at the same horizons

    Returns:
        (integrity_ok, notes). ``notes`` holds every failure message plus, if
        survival was not provided, a note that the conservation check was
        skipped; that note alone does not make ``integrity_ok`` False.

    Raises:
        ValueError: on the first failed check when ``strict`` is True.
    """
    notes: List[str] = []
    # Failures are tracked separately so informational notes never affect `ok`.
    failures: List[str] = []
    model_tag = f"[{name}] " if name else ""

    def _fail(msg: str) -> None:
        full = model_tag + msg
        if strict:
            raise ValueError(full)
        print("WARNING:", full)
        notes.append(msg)
        failures.append(msg)

    cif = np.asarray(cause_cif)
    if cif.ndim != 3:
        _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}")
        return False, notes
    N, K, H = cif.shape
    if H != len(horizons):
        _fail(
            f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})")

    # nanmin/nanmax raise on zero-size input, so only reduce non-empty arrays.
    if cif.size > 0:
        # (5) Finite
        if not np.isfinite(cif).all():
            _fail("integrity: non-finite values (NaN/Inf) in cause_cif")

        # (1) Range
        cmin = float(np.nanmin(cif))
        cmax = float(np.nanmax(cif))
        if cmin < -tol:
            _fail(f"integrity: range min={cmin} < -tol={-tol}")
        if cmax > 1.0 + tol:
            _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}")

        # (2) Monotonicity in horizons (per n,k)
        diffs = np.diff(cif, axis=2)
        if diffs.size > 0 and np.nanmin(diffs) < -tol:
            _fail("integrity: monotonicity violated (found negative diff along horizons)")

    # (3) Probability mass: sum_k CIF <= 1
    mass = np.sum(cif, axis=1)  # (N,H)
    if mass.size > 0:
        mass_max = float(np.nanmax(mass))
        if mass_max > 1.0 + tol:
            _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})")

    # (4) Conservation with survival, if provided
    if survival is None:
        warn = "integrity: survival not provided; skipping conservation check"
        # In strict mode the skip is recorded silently; otherwise it is also printed.
        if not strict:
            print("WARNING:", model_tag + warn)
        notes.append(warn)
    else:
        s = np.asarray(survival, dtype=float)
        if s.shape != (N, H):
            _fail(
                f"integrity: survival shape mismatch (got {s.shape}, expected {(N, H)})")
        elif mass.size > 0:
            err = np.abs((1.0 - s) - mass)
            # Discrete-time should be very tight; exponential may accumulate slightly more numerical error.
            tol_cons = max(float(tol), 1e-4)
            if float(np.nanmax(err)) > tol_cons:
                _fail(
                    f"integrity: conservation violated (max |(1-surv)-sum_cif|={float(np.nanmax(err))}, tol={tol_cons})")

    return len(failures) == 0, notes
|
||
|
||
|
||
# ============================================================
|
||
# Metrics
|
||
# ============================================================
|
||
|
||
# --- Rank-based ROC AUC (ties handled via midranks) ---
|
||
|
||
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of *x*, giving tied values the average of their ranks."""
    values = np.asarray(x, dtype=float)
    size = int(values.shape[0])
    if size == 0:
        return np.asarray([], dtype=float)

    # Stable sort, then locate the positions where the sorted value changes.
    perm = np.argsort(values, kind="mergesort")
    ordered = values[perm]
    change_at = np.nonzero(np.diff(ordered) != 0)[0] + 1

    # Tie-group g covers sorted positions [lo[g], hi[g]).
    bounds = np.concatenate(
        [np.array([0], dtype=int), change_at, np.array([size], dtype=int)]
    )
    lo = bounds[:-1]
    hi = bounds[1:]

    # Mean of the 1-based ranks lo+1 .. hi within each group.
    group_rank = 0.5 * (lo + hi - 1) + 1.0
    ranks_sorted = np.repeat(group_rank, hi - lo).astype(float, copy=False)

    # Scatter back to the original element order.
    ranks = np.empty(size, dtype=float)
    ranks[perm] = ranks_sorted
    return ranks
|
||
|
||
|
||
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """ROC AUC from the Mann–Whitney U statistic, with ties handled by midranks.

    Returns NaN when there are no labels equal to 1 or no labels equal to 0.

    Raises:
        ValueError: if the inputs are not 1D arrays of equal length.
    """
    labels = np.asarray(y_true, dtype=int)
    scores = np.asarray(y_score, dtype=float)
    shapes_ok = (
        labels.ndim == 1
        and scores.ndim == 1
        and labels.shape[0] == scores.shape[0]
    )
    if not shapes_ok:
        raise ValueError("roc_auc_rank expects 1D arrays of equal length")

    is_pos = labels == 1
    n_pos = int(np.sum(is_pos))
    n_neg = int(np.sum(labels == 0))
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    rank_sum_pos = float(np.sum(compute_midrank(scores)[is_pos]))
    # U = rank sum of positives minus its minimum possible value; AUC = U / (m * n).
    u_stat = rank_sum_pos - n_pos * (n_pos + 1) / 2.0
    return float(u_stat / (n_pos * n_neg))
|
||
|
||
|
||
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    """Mean squared difference between predicted probabilities *p* and outcomes *y*."""
    residual = np.asarray(p, dtype=float) - np.asarray(y, dtype=float)
    return float(np.mean(residual ** 2))
|
||
|
||
|
||
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    """Quantile-binned calibration points plus ICI.

    Returns:
        ``{"bins": [...per-bin dicts...], "ici": float}``.
    """
    bin_rows, ici_value = _calibration_bins_and_ici(
        p, y, n_bins=int(n_bins), return_bins=True
    )
    result: Dict[str, Any] = {"bins": bin_rows, "ici": float(ici_value)}
    return result
|
||
|
||
|
||
def calibration_ici_only(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> float:
    """Return only the ICI, skipping construction of the per-bin rows."""
    result = _calibration_bins_and_ici(p, y, n_bins=int(n_bins), return_bins=False)
    ici_value = result[1]
    return float(ici_value)
|
||
|
||
|
||
def _calibration_bins_and_ici(
    p: np.ndarray,
    y: np.ndarray,
    *,
    n_bins: int,
    return_bins: bool,
) -> Tuple[List[Dict[str, Any]], float]:
    """Quantile-bin predictions and compute the calibration error (ICI).

    ICI here is the unweighted mean, over non-empty bins, of
    ``|mean(p) - mean(y)|`` within the bin.

    Returns:
        ``(bins, ici)``. ``bins`` is empty when ``return_bins`` is False or the
        input is empty; ``ici`` is NaN for empty input.
    """
    pred = np.asarray(p, dtype=float)
    obs = np.asarray(y, dtype=float)
    k = int(n_bins)
    if pred.size == 0:
        return [], float("nan")

    # Quantile edges with the outer edges opened up to +/- infinity.
    edges = np.asarray(np.quantile(pred, np.linspace(0.0, 1.0, k + 1)), dtype=float)
    edges[0] = -np.inf
    edges[-1] = np.inf

    # side="right" minus one assigns bin i when edges[i] <= p < edges[i+1].
    which = np.searchsorted(edges, pred, side="right") - 1
    which = np.clip(which, 0, k - 1)

    counts = np.bincount(which, minlength=k).astype(float)
    sum_pred = np.bincount(which, weights=pred, minlength=k).astype(float)
    sum_obs = np.bincount(which, weights=obs, minlength=k).astype(float)

    filled = counts > 0
    if not np.any(filled):
        return [], float("nan")

    mean_pred = np.zeros(k, dtype=float)
    mean_obs = np.zeros(k, dtype=float)
    mean_pred[filled] = sum_pred[filled] / counts[filled]
    mean_obs[filled] = sum_obs[filled] / counts[filled]

    gaps = np.abs(mean_pred[filled] - mean_obs[filled])
    ici = float(np.mean(gaps)) if gaps.size else float("nan")

    if not return_bins:
        return [], ici

    bins: List[Dict[str, Any]] = [
        {
            "bin": int(i),
            "p_mean": float(mean_pred[i]),
            "y_mean": float(mean_obs[i]),
            "n": int(counts[i]),
        }
        for i in np.nonzero(filled)[0].tolist()
    ]
    return bins, ici
|
||
|
||
|
||
def _progress_line(done: int, total: int, prefix: str = "") -> str:
    """Render a fixed-width text progress bar such as ``[###---] 3/6 ( 50.0%)``.

    *total* is floored at 1 and *done* is clamped into [0, total] before use.
    """
    bar_width = 28
    denom = max(1, int(total))
    numer = min(int(done), denom)
    if numer < 0:
        numer = 0

    ratio = numer / denom
    n_hash = int(round(bar_width * ratio))
    # Pad the filled part with dashes up to the full bar width.
    bar = ("#" * n_hash).ljust(bar_width, "-")
    percent = 100.0 * ratio
    return f"{prefix}[{bar}] {numer}/{denom} ({percent:5.1f}%)"
|
||
|
||
|
||
def _should_show_progress(mode: str) -> bool:
    """Decide whether to render progress output for the given mode string.

    - "0", "false", "no", "none", "off": never show.
    - "auto", "1", "true", "yes", "on", "bar": show only if stdout is a TTY
      (falls back to True if that cannot be determined).
    - anything else: show.

    Matching is case-insensitive and ignores surrounding whitespace.
    """
    key = str(mode).strip().lower()
    if key in ("0", "false", "no", "none", "off"):
        return False
    if key not in ("auto", "1", "true", "yes", "on", "bar"):
        return True
    # Recognised "enabled" modes still depend on an interactive stdout.
    try:
        return bool(sys.stdout.isatty())
    except Exception:
        return True
|
||
|
||
|
||
def _safe_float(x: Any, default: float = float("nan")) -> float:
    """Convert *x* with ``float()``; on any exception return ``float(default)``."""
    try:
        value = float(x)
    except Exception:
        # Best-effort conversion: any failure maps to the caller's default.
        value = float(default)
    return value
|
||
|
||
|
||
def _ensure_dir(path: str) -> None:
    """Create directory *path* (including parents); do nothing if it already exists."""
    os.makedirs(name=path, exist_ok=True)
|
||
|
||
|
||
def load_cause_names(path: str = "labels.csv") -> Dict[int, str]:
    """Build a mapping from 0-based line number to label text.

    The file is read as one label per line; per the original contract the
    lines are in disease-id order, so the line number is the cause id.
    Lines that are blank after stripping are left out, and their line
    numbers are not reused.

    Returns:
        The mapping, or an empty dict when *path* does not exist.
    """
    if not os.path.exists(path):
        return {}

    with open(path, "r", encoding="utf-8") as handle:
        stripped = ((row, text.strip()) for row, text in enumerate(handle))
        return {int(row): label for row, label in stripped if label}
|
||
|
||
|
||
def pick_focus_causes(
    *,
    counts_within_tau: Optional[np.ndarray],
    n_disease: int,
    death_cause_id: int = DEFAULT_DEATH_CAUSE_ID,
    k: int = 5,
) -> List[int]:
    """Choose the causes highlighted in user-facing evaluation output.

    Rule:
      1) death_cause_id comes first when it lies in [0, n_disease); otherwise
         a warning is printed and it is omitted.
      2) Up to k further causes follow, ordered by descending count with ties
         broken by smaller cause id; causes with a non-positive count are
         never added.
      3) If counts_within_tau is None, or its length differs from n_disease
         (a warning is printed), the k largest cause ids other than
         death_cause_id are used instead.

    Args:
        counts_within_tau: per-cause event counts, expected shape (n_disease,).
        n_disease: size of the cause id space.
        death_cause_id: cause id that is always listed first.
        k: number of additional causes to add.

    Returns:
        Cause ids without duplicates, in selection order.
    """
    total = int(n_disease)
    death = int(death_cause_id)
    limit = int(k)

    focus: List[int] = []
    if 0 <= death_cause_id < total:
        focus.append(death)
    else:
        print(
            f"WARNING: death_cause_id={death_cause_id} out of range (n_disease={total}); "
            "it will be omitted from focus causes."
        )

    candidates = [cid for cid in range(total) if cid != death]

    use_fallback = counts_within_tau is None
    if not use_fallback:
        c = np.asarray(counts_within_tau).astype(float)
        if c.shape[0] != total:
            print(
                "WARNING: counts_within_tau length mismatch; falling back to coverage proxy ordering."
            )
            use_fallback = True
        else:
            # Rank by (-count, cause_id), then drop causes that never occur.
            ranked = sorted(candidates, key=lambda cid: (-float(c[cid]), int(cid)))
            ranked = [cid for cid in ranked if float(c[cid]) > 0]
            focus.extend(int(cid) for cid in ranked[:limit])

    if use_fallback:
        # Deterministic stand-in ordering: largest cause ids first.
        descending = sorted(candidates, reverse=True)
        focus.extend(int(cid) for cid in descending[:limit])

    # Drop repeats while keeping first-seen order.
    return list(dict.fromkeys(focus))
|
||
|
||
|
||
def write_simple_csv(path: str, fieldnames: List[str], rows: List[Dict[str, Any]]) -> None:
    """Write *rows* to *path* as UTF-8 CSV with a header row.

    The parent directory of *path* is created first if it does not exist.

    Args:
        path: destination file; overwritten if present.
        fieldnames: column names, in output order.
        rows: one dict per output line, keyed by column name.
    """
    parent = os.path.dirname(os.path.abspath(path)) or "."
    os.makedirs(parent, exist_ok=True)

    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
|
||
|
||
|
||
def _sex_slices(sex: Optional[np.ndarray]) -> List[Tuple[str, Optional[np.ndarray]]]:
    """List the (label, mask) slices to evaluate, always starting with ('all', None).

    When *sex* is a 1D array, a slice labelled "0" and/or "1" is appended for
    each of those values that occurs at least once, with a boolean mask that
    selects it. If *sex* is None or not 1D, only the 'all' slice is returned.
    """
    slices: List[Tuple[str, Optional[np.ndarray]]] = [("all", None)]
    if sex is None:
        return slices

    codes = np.asarray(sex)
    if codes.ndim != 1:
        return slices

    masks = ((code, codes == code) for code in (0, 1))
    slices.extend(
        (str(code), mask) for code, mask in masks if int(np.sum(mask)) > 0
    )
    return slices
|
||
|
||
|
||
def _quantile_edges(p: np.ndarray, q: int) -> np.ndarray:
|
||
edges = np.quantile(p, np.linspace(0.0, 1.0, int(q) + 1))
|
||
edges = np.asarray(edges, dtype=float)
|
||
edges[0] = -np.inf
|
||
edges[-1] = np.inf
|
||
return edges
|
||
|
||
|
||
def compute_risk_stratification_bins(
    p: np.ndarray,
    y: np.ndarray,
    *,
    q_default: int = 10,
) -> Tuple[int, List[Dict[str, Any]], Dict[str, Any]]:
    """Stratify individuals into quantile bins of predicted risk and summarize.

    Args:
        p: Predicted risks, one per individual.
        y: Observed outcomes aligned with ``p`` (averaged as rates).
        q_default: Number of quantile bins; overridden to 5 when fewer than
            200 samples are given.

    Returns:
        ``(q, bin_rows, summary)`` where ``q`` is the number of bins used
        (0 for empty input), ``bin_rows`` has one dict per bin (empty bins are
        kept with NaN statistics) and ``summary`` holds the overall rate, the
        top-bin rate, the rate below ``edges[q // 2]``, their ratio, and a
        bin-size-weighted least-squares slope of observed rate on mean
        predicted risk.

    Note:
        The summary keys say "top_decile" / "bottom_half", which is exact for
        q=10. With q=5 they cover the top fifth and the lowest two fifths.
    """
    nan = float("nan")
    scores = np.asarray(p, dtype=float)
    outcomes = np.asarray(y, dtype=float)
    n_samples = int(scores.shape[0])
    if n_samples == 0:
        empty_summary = {
            "y_overall": nan,
            "top_decile_y_rate": nan,
            "bottom_half_y_rate": nan,
            "lift_top10_vs_bottom50": nan,
            "slope_pred_vs_obs": nan,
        }
        return 0, [], empty_summary

    # Small samples cannot support many strata, so fall back to quintiles.
    n_strata = 5 if n_samples < 200 else int(q_default)

    edges = _quantile_edges(scores, n_strata)
    y_overall = float(np.mean(outcomes))

    def _between(lower: float, upper: float) -> np.ndarray:
        # Half-open interval (lower, upper], matching the bin convention.
        return (scores > lower) & (scores <= upper)

    def _rate_or_nan(selector: np.ndarray) -> float:
        if int(np.sum(selector)) > 0:
            return float(np.mean(outcomes[selector]))
        return nan

    bin_rows: List[Dict[str, Any]] = []
    # (p_mean, y_rate, n_bin) for non-empty bins only; feeds the slope fit.
    populated: List[Tuple[float, float, int]] = []

    for idx in range(n_strata):
        members = _between(edges[idx], edges[idx + 1])
        count = int(np.sum(members))
        # Start from the "empty bin" row and fill it in when there are members,
        # so empty bins still appear (with NaNs) for consistent plotting.
        row: Dict[str, Any] = {
            "q": int(idx + 1),
            "n_bin": count,
            "p_mean": nan,
            "y_rate": nan,
            "y_overall": y_overall,
            "lift_vs_overall": nan,
        }
        if count > 0:
            p_mean = float(np.mean(scores[members]))
            y_rate = float(np.mean(outcomes[members]))
            row["p_mean"] = p_mean
            row["y_rate"] = y_rate
            row["lift_vs_overall"] = float(y_rate / y_overall) if y_overall > 0 else nan
            populated.append((p_mean, y_rate, count))
        bin_rows.append(row)

    # Summary statistics.
    top_y = _rate_or_nan(_between(edges[n_strata - 1], edges[n_strata]))
    bot_y = _rate_or_nan(_between(edges[0], edges[n_strata // 2]))
    if np.isfinite(top_y) and np.isfinite(bot_y) and bot_y > 0:
        lift_top_vs_bottom = float(top_y / bot_y)
    else:
        lift_top_vs_bottom = nan

    slope = nan
    if len(populated) >= 2:
        # Weighted least-squares slope of y_rate ~ p_mean, weights = bin sizes.
        x, yy, w = (np.asarray(col, dtype=float) for col in zip(*populated))
        x_center = float(np.average(x, weights=w))
        y_center = float(np.average(yy, weights=w))
        denom = float(np.sum(w * (x - x_center) ** 2))
        if denom > 0:
            slope = float(np.sum(w * (x - x_center) * (yy - y_center)) / denom)

    summary = {
        "y_overall": y_overall,
        "top_decile_y_rate": top_y,
        "bottom_half_y_rate": bot_y,
        "lift_top10_vs_bottom50": lift_top_vs_bottom,
        "slope_pred_vs_obs": slope,
    }
    return n_strata, bin_rows, summary
|
||
|
||
|
||
def compute_capture_points(
    p: np.ndarray,
    y: np.ndarray,
    k_pcts: Sequence[int],
) -> List[Dict[str, Any]]:
    """Compute event capture when targeting the top-k% highest-risk individuals.

    Individuals are ranked by descending ``p``. For each percentage in
    ``k_pcts`` the top ``ceil(n * k / 100)`` (clamped to [1, n]) are "targeted".

    Returns:
        One dict per requested percentage with the targeted count, events among
        the targeted, total events, the capture rate (NaN when there are no
        events at all) and the event rate within the targeted group.
        An empty list is returned for empty input.
    """
    scores = np.asarray(p, dtype=float)
    outcomes = np.asarray(y, dtype=float)
    n_samples = int(scores.shape[0])
    if n_samples == 0:
        return []

    ranked_outcomes = outcomes[np.argsort(-scores)]
    events_total = float(np.sum(ranked_outcomes))

    def _point(k_pct: Any) -> Dict[str, Any]:
        n_targeted = int(math.ceil(n_samples * float(k_pct) / 100.0))
        # Always target at least one person and never more than exist.
        n_targeted = max(1, min(n_targeted, n_samples))
        events_targeted = float(np.sum(ranked_outcomes[:n_targeted]))
        if events_total > 0:
            capture = float(events_targeted / events_total)
        else:
            capture = float("nan")
        return {
            "k_pct": int(k_pct),
            "n_targeted": int(n_targeted),
            "events_targeted": float(events_targeted),
            "events_total": float(events_total),
            "event_capture_rate": capture,
            "precision_in_targeted": float(events_targeted / float(n_targeted)),
        }

    return [_point(k_pct) for k_pct in k_pcts]
|
||
|
||
|
||
def compute_event_rate_at_topk_causes(
    p_tau: np.ndarray,
    y_tau: np.ndarray,
    topk_list: Sequence[int],
) -> List[Dict[str, Any]]:
    """Compute Event Rate@K and Recall@K for cross-cause prioritization.

    For each individual, causes are ranked by predicted risk ``p_tau`` at a
    fixed horizon. For each K, the top-K causes are selected and compared with
    the causes that actually occurred within the horizon.

    Args:
        p_tau: (N, K_total) predicted CIFs at a fixed horizon.
        y_tau: (N, K_total) binary labels (0/1), same shape as ``p_tau``.
        topk_list: K values to evaluate. Non-positive values are dropped and
            duplicates merged; values above K_total are clamped to K_total.

    Returns:
        One row per evaluated K with:
            - topk (the effective, clamped K)
            - event_rate_mean / event_rate_median (hits / K per individual)
            - recall_mean / recall_median (over individuals with >=1 true
              cause; NaN when there are none)
            - n_total / n_valid_recall
        When N == 0 or K_total == 0, one all-NaN row is returned per entry of
        ``topk_list`` (unsanitized, with topk floored at 1).

    Raises:
        ValueError: If the inputs are not 2-D arrays of identical shape.
    """
    p = np.asarray(p_tau, dtype=float)
    y = np.asarray(y_tau, dtype=float)
    if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape:
        raise ValueError(
            "compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape")

    nan = float("nan")
    n, k_total = p.shape
    if n == 0 or k_total == 0:
        return [
            {
                "topk": int(max(1, int(kk))),
                "event_rate_mean": nan,
                "event_rate_median": nan,
                "recall_mean": nan,
                "recall_median": nan,
                "n_total": int(n),
                "n_valid_recall": 0,
            }
            for kk in topk_list
        ]

    # Sanitize K list.
    topks = sorted({int(x) for x in topk_list if int(x) > 0})
    if not topks:
        return []

    # topks are positive and k_total >= 1 here, so max_k >= 1.
    max_k = min(topks[-1], int(k_total))

    # Efficient: get top max_k causes per individual, then sort within those.
    part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k]  # (N, max_k)
    p_part = np.take_along_axis(p, part, axis=1)
    order = np.argsort(-p_part, axis=1)
    top_sorted = np.take_along_axis(part, order, axis=1)  # (N, max_k)

    # Recall denominators do not depend on K: compute them once.
    g = np.sum(y, axis=1)  # number of true causes per person
    valid = g > 0
    n_valid = int(np.sum(valid))

    out_rows: List[Dict[str, Any]] = []
    for kk in topks:
        kk_eff = min(int(kk), int(k_total))
        y_sel = np.take_along_axis(y, top_sorted[:, :kk_eff], axis=1)
        # Selected true causes per person.
        hit = np.sum(y_sel, axis=1)
        # Precision-like: fraction of selected causes that occur.
        per_person = hit / float(kk_eff)

        # Recall@K: fraction of true causes covered by top-K
        # (undefined, hence NaN, when the person has no true cause).
        recall = np.full((n,), np.nan, dtype=float)
        recall[valid] = hit[valid] / g[valid]
        out_rows.append(
            {
                "topk": int(kk_eff),
                "event_rate_mean": float(np.mean(per_person)),
                "event_rate_median": float(np.median(per_person)),
                "recall_mean": float(np.nanmean(recall)) if n_valid > 0 else nan,
                "recall_median": float(np.nanmedian(recall)) if n_valid > 0 else nan,
                "n_total": int(n),
                "n_valid_recall": n_valid,
            }
        )
    return out_rows
|
||
|
||
|
||
def compute_random_ranking_baseline_topk(
    y_tau: np.ndarray,
    topk_list: Sequence[int],
    *,
    z: float = 1.645,
) -> List[Dict[str, Any]]:
    """Random ranking baseline for Event Rate@K and Recall@K.

    Baseline definition:
        - For each individual, pick K causes uniformly at random without replacement.
        - EventRate@K = (# selected causes that occur) / K.
        - Recall@K = (# selected causes that occur) / (# causes that occur),
          averaged over individuals with >=1 true cause.

    This function computes the expected baseline mean and an approximate
    central range for the population mean using a normal approximation of the
    hypergeometric variance.

    Args:
        y_tau: (N, K_total) binary labels.
        topk_list: K values. Non-positive values are dropped, duplicates are
            merged, and values above K_total are clamped to K_total.
        z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%).

    Returns:
        Rows with baseline means and p05/p95 for both metrics. Rows are all-NaN
        when N == 0 or K_total == 0; recall bounds are NaN when no individual
        has a true cause.

    Raises:
        ValueError: If ``y_tau`` is not 2-D.
    """
    y = np.asarray(y_tau, dtype=float)
    if y.ndim != 2:
        raise ValueError(
            "compute_random_ranking_baseline_topk expects y_tau with shape (N,K)")

    n, k_total = y.shape
    topks = sorted({int(x) for x in topk_list if int(x) > 0})
    if not topks:
        return []

    nan = float("nan")
    method = "random_ranking_hypergeometric_normal_approx"

    g = np.sum(y, axis=1)  # (N,) number of true causes per person
    valid = g > 0
    n_valid = int(np.sum(valid))

    # Quantities that do not depend on K are computed once.
    degenerate = n == 0 or k_total == 0
    if not degenerate:
        prevalence = g / float(k_total)  # p = g / K_total per person
        # Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total.
        er_mean = float(np.mean(prevalence))

    out: List[Dict[str, Any]] = []
    for kk in topks:
        kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk)
        if degenerate or kk_eff <= 0:
            out.append(
                {
                    "topk": int(max(1, kk_eff)),
                    "baseline_event_rate_mean": nan,
                    "baseline_event_rate_p05": nan,
                    "baseline_event_rate_p95": nan,
                    "baseline_recall_mean": nan,
                    "baseline_recall_p05": nan,
                    "baseline_recall_p95": nan,
                    "n_total": int(n),
                    "n_valid_recall": int(n_valid),
                    "k_total": int(k_total),
                    "baseline_method": method,
                }
            )
            continue

        # Variance of hypergeometric count X:
        # Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p=g/K_total.
        if k_total > 1 and kk_eff < k_total:
            finite_corr = float(k_total - kk_eff) / float(k_total - 1)
            var_x = float(kk_eff) * prevalence * (1.0 - prevalence) * finite_corr
        else:
            # Selecting every cause (or a single-cause universe) is deterministic.
            var_x = np.zeros_like(g, dtype=float)

        var_er = var_x / (float(kk_eff) ** 2)
        se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n))
        er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0))
        er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0))

        # Expected Recall@K for individuals with g>0 is K/K_total (clipped).
        rec_mean = float(min(float(kk_eff) / float(k_total), 1.0))
        if n_valid > 0:
            # Var( X / g ) = Var(X) / g^2 (approx; g is fixed per individual)
            var_rec_v = var_x[valid] / (g[valid] ** 2)
            se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid)
            rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0))
            rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0))
        else:
            rec_p05 = nan
            rec_p95 = nan

        out.append(
            {
                "topk": int(kk_eff),
                "baseline_event_rate_mean": er_mean,
                "baseline_event_rate_p05": er_p05,
                "baseline_event_rate_p95": er_p95,
                "baseline_recall_mean": rec_mean,
                "baseline_recall_p05": float(rec_p05),
                "baseline_recall_p95": float(rec_p95),
                "n_total": int(n),
                "n_valid_recall": int(n_valid),
                "k_total": int(k_total),
                "baseline_method": method,
            }
        )
    return out
|
||
|
||
|
||
def count_occurs_within_horizon(
    loader: DataLoader,
    offset_years: float,
    tau_years: float,
    n_disease: int,
    device: str,
) -> Tuple[np.ndarray, int]:
    """Count, per disease, how many individuals have it within tau of their context.

    For every batch, ``select_context_indices`` picks a context position per
    individual (and a keep-mask). An event counts when it sits at a sequence
    position strictly after the context index, its time lies in
    ``[t0, t0 + tau_years * DAYS_PER_YEAR]`` (t0 = time at the context index),
    and its token id is >= 2; the disease index is ``token - 2``. Each
    (individual, disease) pair is counted at most once.

    Returns:
        ``(counts, n_total_eval)``: a length-``n_disease`` integer array and
        the number of individuals kept by the context selection.
    """
    stride = int(n_disease)
    horizon_days = float(tau_years) * DAYS_PER_YEAR

    counts = torch.zeros((n_disease,), dtype=torch.long, device=device)
    n_total_eval = 0

    # Only events and times are needed; the remaining batch fields are unused.
    for event_seq, time_seq, _cont_feats, _cate_feats, _sexes in loader:
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)

        keep, t_ctx, _ = select_context_indices(event_seq, time_seq, offset_years)
        if not keep.any():
            continue
        n_total_eval += int(keep.sum().item())

        event_seq, time_seq, t_ctx = event_seq[keep], time_seq[keep], t_ctx[keep]
        n_rows, seq_len = event_seq.shape

        row_ids = torch.arange(n_rows, device=device)
        t_start = time_seq[row_ids, t_ctx]
        t_end = t_start + horizon_days

        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(n_rows, -1)
        after_context = positions > t_ctx.unsqueeze(1)
        inside_time = (time_seq >= t_start.unsqueeze(1)) & (time_seq <= t_end.unsqueeze(1))
        is_disease_token = (event_seq >= 2) & (event_seq != 0)
        in_window = after_context & inside_time & is_disease_token
        if not in_window.any():
            continue

        b_idx, t_idx = in_window.nonzero(as_tuple=True)
        disease_ids = (event_seq[b_idx, t_idx] - 2).to(torch.long)

        # Encode (person, disease) as one integer so torch.unique de-duplicates
        # repeated occurrences of the same disease for the same person.
        pair_keys = torch.unique(b_idx.to(torch.long) * stride + disease_ids)
        hit_diseases = pair_keys % stride
        counts.scatter_add_(
            0, hit_diseases, torch.ones_like(hit_diseases, dtype=torch.long))

    return counts.detach().cpu().numpy(), int(n_total_eval)
|
||
|
||
|
||
# ============================================================
|
||
# Evaluation core
|
||
# ============================================================
|
||
|
||
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float], Dict[str, Any]]:
    """Build the backbone and prediction head described by a training config.

    Args:
        cfg: Training configuration. Reads ``model_type``, ``loss_type``,
            ``n_embd``, ``n_head``, ``n_layer`` and several optional keys
            (``bin_edges``, ``lognormal_centers``/``centers``, ``pdrop``,
            ``age_encoder``, bandwidth/eps settings, embedding path).
        dataset: Supplies ``n_disease``, ``n_cont``, ``n_cate``, ``cate_dims``.
        device: Device the modules are moved to.
        checkpoint_path: Optional checkpoint. For the lognormal-basis loss it
            is inspected to decide between the (K, R) head layout and the
            legacy flattened (1 + K*R) layout.

    Returns:
        ``(backbone, head, loss_type, bin_edges, loss_params)``. Weights are
        NOT loaded here; the checkpoint is only used for shape inference.

    Raises:
        ValueError: Unsupported ``loss_type``/``model_type``, missing lognormal
            centers, or any failure while inspecting the checkpoint.
    """
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])

    loss_params: Dict[str, Any] = {}

    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif loss_type == "lognormal_basis_binned_hazard_cif":
        centers = cfg.get("lognormal_centers", None)
        if centers is None:
            centers = cfg.get("centers", None)
        if not isinstance(centers, list) or len(centers) == 0:
            raise ValueError(
                "lognormal_basis_binned_hazard_cif requires 'lognormal_centers' (list of mu_r in log-time) in train_config.json"
            )
        r = len(centers)
        desired_total = int(dataset.n_disease) * int(r)
        legacy_total = 1 + desired_total

        # Prefer the new shape (K,R) but keep compatibility with older checkpoints
        # that used a single flattened dimension (1 + K*R).
        out_dims = [int(dataset.n_disease), int(r)]
        if checkpoint_path:
            try:
                # NOTE: torch.load unpickles the file; only point this at
                # checkpoints from a trusted source.
                ckpt = torch.load(checkpoint_path, map_location="cpu")
                head_sd = ckpt.get("head_state_dict", {})
                w = head_sd.get("net.2.weight", None)
                if isinstance(w, torch.Tensor) and w.ndim == 2:
                    out_features = int(w.shape[0])
                    if out_features == legacy_total:
                        out_dims = [legacy_total]
                    elif out_features == desired_total:
                        out_dims = [int(dataset.n_disease), int(r)]
                    else:
                        raise ValueError(
                            f"Checkpoint head out_features={out_features} does not match expected {desired_total} (K*R) or {legacy_total} (1+K*R)"
                        )
            except Exception as e:
                # Deliberately broad: any load/inspection failure is reported
                # uniformly, with the original exception kept as the cause.
                raise ValueError(
                    f"Failed to infer head output dims from checkpoint={checkpoint_path}: {e}"
                ) from e
        loss_params["centers"] = centers
        loss_params["bandwidth_min"] = float(cfg.get("bandwidth_min", 1e-3))
        loss_params["bandwidth_max"] = float(cfg.get("bandwidth_max", 10.0))
        loss_params["bandwidth_init"] = float(cfg.get("bandwidth_init", 0.7))
        loss_params["loss_eps"] = float(cfg.get("loss_eps", 1e-8))
        loss_params["alpha_floor"] = float(cfg.get("alpha_floor", 0.0))
    else:
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")

    # Pick the backbone class plus any class-specific constructor arguments;
    # the remaining arguments are shared by both architectures.
    extra_kwargs: Dict[str, Any] = {}
    if model_type == "delphi_fork":
        backbone_cls = DelphiFork
    elif model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, fallback to pretrained_emd_path.
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in {"", None}:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in {"", None}:
            run_dir = os.path.dirname(os.path.abspath(
                checkpoint_path)) if checkpoint_path else ""
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone_cls = SapDelphi
        extra_kwargs = {
            "pretrained_weights_path": emb_path,
            "freeze_embeddings": True,
        }
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    backbone = backbone_cls(
        n_disease=dataset.n_disease,
        n_tech_tokens=2,
        n_embd=int(cfg["n_embd"]),
        n_head=int(cfg["n_head"]),
        n_layer=int(cfg["n_layer"]),
        pdrop=float(cfg.get("pdrop", 0.0)),
        age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
        n_cont=dataset.n_cont,
        n_cate=dataset.n_cate,
        cate_dims=dataset.cate_dims,
        **extra_kwargs,
    ).to(device)

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
    return backbone, head, loss_type, bin_edges, loss_params
|
||
|
||
|
||
@torch.no_grad()
def predict_cifs_for_model(
    backbone: torch.nn.Module,
    head: torch.nn.Module,
    loss_type: str,
    bin_edges: Sequence[float],
    loss_params: Dict[str, Any],
    loader: DataLoader,
    device: str,
    offset_years: float,
    eval_horizons: Sequence[float],
    n_disease: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run model and produce cause-specific, time-dependent CIF outputs.

    Both modules are put in eval mode. For each batch, individuals rejected by
    ``select_context_indices`` are dropped; the backbone state at the selected
    context position is fed to the head and converted to CIFs according to
    ``loss_type``.

    Returns:
        cif_full: (N, K, H)
        survival: (N, H)
        y_cause_within_tau: (N, K, H)
        sex: the per-individual ``sexes`` values of the kept individuals,
            concatenated in the same order as the other outputs

    Raises:
        ValueError: Unsupported ``loss_type``, or the lognormal-basis loss is
            requested without ``loss_params['centers']`` and
            ``loss_params['sigma']``.
        RuntimeError: No individual survived the context/offset filter.

    NOTE: Evaluation is cause-specific and horizon-specific (multi-disease risk).
    NOTE(review): 'sigma' is not among the keys instantiate_model_and_head puts
    into loss_params; presumably the caller adds it - confirm.
    """
    backbone.eval()
    head.eval()

    # We will accumulate in CPU lists, then concat.
    cif_full_list: List[np.ndarray] = []
    survival_list: List[np.ndarray] = []
    y_cause_within_list: List[np.ndarray] = []
    sex_list: List[np.ndarray] = []

    for batch in loader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        keep, t_ctx, _ = select_context_indices(
            event_seq, time_seq, offset_years)
        if not keep.any():
            continue

        # filter batch
        event_seq = event_seq[keep]
        time_seq = time_seq[keep]
        cont_feats = cont_feats[keep]
        cate_feats = cate_feats[keep]
        sexes_k = sexes[keep]
        t_ctx = t_ctx[keep]

        h = backbone(event_seq, time_seq, sexes_k,
                     cont_feats, cate_feats)  # (B,L,D)
        b = torch.arange(h.size(0), device=device)
        c = h[b, t_ctx]  # (B,D) hidden state at each individual's context position
        logits = head(c)

        if loss_type == "exponential":
            cif_full, survival = cifs_from_exponential_logits(
                logits, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "discrete_time_cif":
            cif_full, survival = cifs_from_discrete_time_logits(
                logits, bin_edges, eval_horizons, return_survival=True)  # (B,K,H), (B,H)
        elif loss_type == "lognormal_basis_binned_hazard_cif":
            centers = loss_params.get("centers", None)
            sigma = loss_params.get("sigma", None)
            if centers is None or sigma is None:
                raise ValueError(
                    "lognormal_basis_binned_hazard_cif requires loss_params['centers'] and loss_params['sigma']")
            cif_full, survival = cifs_from_lognormal_basis_binned_hazard_logits(
                logits,
                centers=centers,
                sigma=sigma,
                bin_edges=bin_edges,
                taus=eval_horizons,
                eps=float(loss_params.get("loss_eps", 1e-8)),
                alpha_floor=float(loss_params.get("alpha_floor", 0.0)),
                return_survival=True,
            )
        else:
            raise ValueError(f"Unsupported loss_type: {loss_type}")

        # Within-horizon labels for all causes: disease k occurs within tau after context.
        y_within_full = torch.stack(
            [
                multi_hot_ever_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau),
                    n_disease=int(n_disease),
                ).to(torch.float32)
                for tau in eval_horizons
            ],
            dim=2,
        )  # (B,K,H)

        cif_full_list.append(cif_full.detach().cpu().numpy())
        survival_list.append(survival.detach().cpu().numpy())
        y_cause_within_list.append(y_within_full.detach().cpu().numpy())
        sex_list.append(sexes_k.detach().cpu().numpy())

    if not cif_full_list:
        raise RuntimeError(
            "No valid samples for evaluation (all batches filtered out by offset).")

    # All four lists are appended together, so none of them is empty here.
    cif_full = np.concatenate(cif_full_list, axis=0)
    survival = np.concatenate(survival_list, axis=0)
    y_cause_within = np.concatenate(y_cause_within_list, axis=0)
    sex = np.concatenate(sex_list, axis=0)

    return cif_full, survival, y_cause_within, sex
|
||
|
||
|
||
def evaluate_one_model(
    model_name: str,
    cif_full: np.ndarray,
    y_cause_within_tau: np.ndarray,
    eval_horizons: Sequence[float],
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    calib_cause_ids: Optional[Sequence[int]],
    n_calib_bins: int = 10,
    metric_workers: int = 0,
    progress: str = "auto",
) -> None:
    """Compute per-cause metrics for ALL diseases and append them to the output lists.

    Args:
        model_name: Label written into every row.
        cif_full: (N, K, H) predicted CIFs.
        y_cause_within_tau: (N, K, H) within-horizon labels.
        eval_horizons: The H horizons, in the order of the last axis.
        out_rows: Extended in place with ``cause_brier``, ``cause_ici`` and
            ``cause_auc`` rows for every cause and horizon.
        calib_rows: Extended in place with calibration-bin rows.
        calib_cause_ids: Causes whose calibration bins are exported. When this
            is None or empty, bins are exported for every cause.
        n_calib_bins: Number of calibration bins.
        metric_workers: Thread count for per-cause metrics (<=0 selects
            ``min(8, cpu_count)``; 1 runs serially).
        progress: Passed to ``_should_show_progress``.

    Raises:
        ValueError: If the arrays are not 3-D, differ in shape, or H does not
            match ``len(eval_horizons)``.

    Notes:
        - Rows are appended in ascending cause order within each horizon,
          regardless of the order in which worker threads finish, so output
          files are reproducible across runs.
    """
    cif_full = np.asarray(cif_full, dtype=float)
    y_cause_within_tau = np.asarray(y_cause_within_tau, dtype=float)
    if cif_full.ndim != 3 or y_cause_within_tau.ndim != 3:
        raise ValueError(
            "Expected cif_full and y_cause_within_tau with shape (N, K, H)")
    if cif_full.shape != y_cause_within_tau.shape:
        raise ValueError(
            f"Shape mismatch: cif_full {cif_full.shape} vs y_cause_within_tau {y_cause_within_tau.shape}"
        )

    _, K, H = cif_full.shape
    if H != len(eval_horizons):
        raise ValueError("H mismatch between cif_full and eval_horizons")

    calib_set = {int(x) for x in calib_cause_ids} if calib_cause_ids is not None else set()

    workers = int(metric_workers)
    if workers <= 0:
        workers = int(min(8, os.cpu_count() or 1))
    workers = max(1, workers)
    show_progress = _should_show_progress(progress)

    def _report(done: int, prefix: str) -> None:
        # In-place progress line (carriage return, no newline).
        if show_progress:
            sys.stdout.write(
                "\r" + _progress_line(done, int(K), prefix=prefix))
            sys.stdout.flush()

    def _eval_chunk(
        *,
        tau: float,
        p_tau: np.ndarray,
        y_tau: np.ndarray,
        brier_by_cause: np.ndarray,
        cause_ids: np.ndarray,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int]:
        # Metrics for a contiguous group of causes at one horizon.
        local_rows: List[Dict[str, Any]] = []
        local_calib: List[Dict[str, Any]] = []
        for cid in cause_ids.tolist():
            p = p_tau[:, cid]
            y = y_tau[:, cid]

            local_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_brier",
                    "horizon": float(tau),
                    "cause": int(cid),
                    "value": float(brier_by_cause[cid]),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

            # ICI: compute bins only if we will export them.
            need_bins = (not calib_set) or (int(cid) in calib_set)
            if need_bins:
                cal = calibration_deciles(p, y, n_bins=n_calib_bins)
                ici = float(cal["ici"])
            else:
                cal = None
                ici = calibration_ici_only(p, y, n_bins=n_calib_bins)

            local_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_ici",
                    "horizon": float(tau),
                    "cause": int(cid),
                    "value": float(ici),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

            # Secondary: discrimination via AUC at the same horizon (point estimate only).
            auc = roc_auc_rank(y, p)

            local_rows.append(
                {
                    "model_name": model_name,
                    "metric_name": "cause_auc",
                    "horizon": float(tau),
                    "cause": int(cid),
                    "value": float(auc),
                    "ci_low": "",
                    "ci_high": "",
                }
            )

            if need_bins and cal is not None:
                for binfo in cal.get("bins", []):
                    local_calib.append(
                        {
                            "model_name": model_name,
                            "task": "cause_k",
                            "horizon": float(tau),
                            "cause_id": int(cid),
                            "bin_index": int(binfo["bin"]),
                            "p_mean": float(binfo["p_mean"]),
                            "y_mean": float(binfo["y_mean"]),
                            "n_in_bin": int(binfo["n"]),
                        }
                    )
        return local_rows, local_calib, int(cause_ids.size)

    # Cause-specific, time-dependent metrics per horizon.
    for h_i, tau in enumerate(eval_horizons):
        p_tau = cif_full[:, :, h_i]  # (N, K)
        y_tau = y_cause_within_tau[:, :, h_i]  # (N, K)

        # Vectorized Brier for speed.
        brier_by_cause = np.mean((p_tau - y_tau) ** 2, axis=0)  # (K,)

        # Parallelize disease-level metrics; chunk to avoid millions of futures.
        all_ids = np.arange(int(K), dtype=int)
        chunks = [ch for ch in np.array_split(all_ids, workers) if int(ch.size) > 0]
        # One result slot per chunk so rows can be emitted in chunk order.
        results: List[Any] = [None] * len(chunks)
        shared_kwargs = {
            "tau": float(tau),
            "p_tau": p_tau,
            "y_tau": y_tau,
            "brier_by_cause": brier_by_cause,
        }
        done = 0
        prefix = f"[{model_name}] tau={float(tau)}y "
        t0 = time.time()

        if workers <= 1:
            for idx, ch in enumerate(chunks):
                results[idx] = _eval_chunk(cause_ids=ch, **shared_kwargs)
                done += int(results[idx][2])
                _report(done, prefix)
        else:
            with ThreadPoolExecutor(max_workers=workers) as ex:
                fut_to_idx = {
                    ex.submit(_eval_chunk, cause_ids=ch, **shared_kwargs): idx
                    for idx, ch in enumerate(chunks)
                }
                # as_completed drives the progress display only; results are
                # stored by chunk index so the final order is deterministic.
                for fut in as_completed(fut_to_idx):
                    res = fut.result()
                    results[fut_to_idx[fut]] = res
                    done += int(res[2])
                    _report(done, prefix)

        for r_chunk, c_chunk, _ in results:
            out_rows.extend(r_chunk)
            calib_rows.extend(c_chunk)

        if show_progress:
            dt = time.time() - t0
            sys.stdout.write("\r" + _progress_line(int(K),
                             int(K), prefix=prefix) + f" ({dt:.1f}s)\n")
            sys.stdout.flush()
|
||
|
||
|
||
def summarize_over_diseases(
    rows: List[Dict[str, Any]],
    *,
    model_name: str,
    eval_horizons: Sequence[float],
    metrics: Sequence[str] = ("cause_brier", "cause_ici", "cause_auc"),
) -> List[Dict[str, Any]]:
    """Summarize mean/median of each metric over diseases (per horizon).

    Args:
        rows: Per-cause metric rows (keys ``model_name``, ``metric_name``,
            ``horizon``, ``value``).
        model_name: Only rows of this model are considered.
        eval_horizons: Horizons to report, in output order.
        metrics: Metric names to report, in output order.

    Returns:
        One row per (horizon, metric) with ``mean``, ``median`` and
        ``n_valid`` (number of finite values aggregated). Combinations without
        any finite value get NaN statistics and ``n_valid`` = 0.
    """
    # Built once instead of once per row.
    wanted = set(metrics)

    # (metric_name, horizon) -> finite values across diseases
    bucket: Dict[Tuple[str, float], List[float]] = {}
    for r in rows:
        if r.get("model_name") != model_name:
            continue
        m = str(r.get("metric_name"))
        if m not in wanted:
            continue
        h = _safe_float(r.get("horizon"))
        v = _safe_float(r.get("value"))
        # Rows with a non-finite horizon or value cannot be aggregated.
        if not np.isfinite(h) or not np.isfinite(v):
            continue
        bucket.setdefault((m, float(h)), []).append(float(v))

    out: List[Dict[str, Any]] = []
    for tau in eval_horizons:
        ht = float(tau)
        for m in metrics:
            vals = bucket.get((str(m), ht), [])
            if vals:
                arr = np.asarray(vals, dtype=float)
                mean_v = float(np.mean(arr))
                med_v = float(np.median(arr))
                n_valid = int(arr.size)
            else:
                mean_v = float("nan")
                med_v = float("nan")
                n_valid = 0
            out.append(
                {
                    "model_name": str(model_name),
                    "metric_name": str(m),
                    "horizon": ht,
                    "mean": mean_v,
                    "median": med_v,
                    "n_valid": n_valid,
                }
            )
    return out
|
||
|
||
|
||
def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write calibration-bin rows to ``path`` as a UTF-8 CSV with a fixed header.

    Args:
        path: Destination file (overwritten).
        rows: Dicts keyed by the calibration columns listed below.
    """
    fieldnames = [
        "model_name",
        "task",
        "horizon",
        "cause_id",
        "bin_index",
        "p_mean",
        "y_mean",
        "n_in_bin",
    ]
    # Explicit UTF-8 keeps output independent of the platform's default
    # encoding and consistent with write_simple_csv.
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)
|
||
|
||
|
||
def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write per-cause metric rows to ``path`` as a UTF-8 CSV with a fixed header.

    Args:
        path: Destination file (overwritten).
        rows: Dicts keyed by the result columns listed below.
    """
    fieldnames = [
        "model_name",
        "metric_name",
        "horizon",
        "cause",
        "value",
        "ci_low",
        "ci_high",
    ]
    # Explicit UTF-8 keeps output independent of the platform's default
    # encoding and consistent with write_simple_csv.
    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)
|
||
|
||
|
||
def _make_eval_tag(split: str, offset_years: float) -> str:
    """Build a short filename tag such as ``test_offset0.5y``.

    The offset is rendered with four decimals, then trailing zeros and a
    dangling decimal point are stripped.
    """
    fixed = format(float(offset_years), ".4f")
    trimmed = fixed.rstrip("0").rstrip(".")
    return f"{split}_offset{trimmed}y"
|
||
|
||
|
||
def _build_eval_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the unified CIF evaluation script."""
    ap = argparse.ArgumentParser(
        description="Unified downstream evaluation via CIFs")
    ap.add_argument("--models_json", type=str, required=True,
                    help="Path to models list JSON")
    ap.add_argument("--data_prefix", type=str,
                    default="ukb", help="Dataset prefix")
    ap.add_argument("--split", type=str, default="test",
                    choices=["train", "val", "test", "all"], help="Which split to evaluate")
    ap.add_argument("--offset_years", type=float, default=0.5,
                    help="Anti-leakage offset (years)")
    ap.add_argument("--eval_horizons", type=float,
                    nargs="*", default=DEFAULT_EVAL_HORIZONS)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=0)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--device", type=str,
                    default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--out_csv", type=str, default="eval_summary.csv")
    ap.add_argument("--out_meta_json", type=str, default="eval_meta.json")

    # Integrity checks
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)

    # Speed/UX
    ap.add_argument(
        "--metric_workers",
        type=int,
        default=0,
        help="Threads for per-disease metrics (0=auto, 1=disable parallelism)",
    )
    ap.add_argument(
        "--progress",
        type=str,
        default="auto",
        choices=["auto", "bar", "none"],
        help="Progress visualization during per-disease evaluation",
    )

    # Export settings for user-facing experiments
    ap.add_argument("--export_dir", type=str, default="eval_exports")
    ap.add_argument("--death_cause_id", type=int,
                    default=DEFAULT_DEATH_CAUSE_ID)
    ap.add_argument("--focus_k", type=int, default=5,
                    help="Additional non-death causes to include")
    ap.add_argument("--capture_k_pcts", type=int,
                    nargs="*", default=[1, 5, 10, 20])
    ap.add_argument(
        "--capture_curve_max_pct",
        type=int,
        default=50,
        help="If >0, also export a dense capture curve for k=1..max_pct",
    )

    # High-risk cause concentration (cross-cause prioritization)
    ap.add_argument(
        "--cause_concentration_topk",
        type=int,
        nargs="*",
        default=[5, 10, 20, 50],
        help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
    )
    ap.add_argument(
        "--cause_concentration_write_random_baseline",
        action="store_true",
        default=False,
        help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)",
    )
    return ap


def main() -> int:
    """Command-line entry point: evaluate every model listed in ``--models_json``.

    Steps, in order:
      1. Parse arguments and seed the RNGs.
      2. Using the first model's training config, build the evaluation split,
         count per-cause occurrences within the largest horizon, and pick the
         focus causes (exported to ``<export_dir>/focus_causes.csv``).
      3. For each model: load config and checkpoint, predict CIFs, run the
         integrity check, compute per-disease metrics and their summary,
         accumulate the cause-concentration and capture exports, and write
         per-model CSV/JSON files next to the checkpoint.
      4. Write the cross-model summary CSV, calibration-bin CSV, experiment
         exports, the manifest markdown and the meta JSON.

    Returns:
        0 on success (used as the process exit code).

    Raises:
        ValueError: If ``--eval_horizons`` is given without values, or if the
            models JSON yields no model specs.
    """
    ap = _build_eval_arg_parser()
    args = ap.parse_args()

    # nargs="*" accepts a bare `--eval_horizons`; fail fast with a clear message
    # instead of an opaque max() error after the dataset has been loaded.
    if not args.eval_horizons:
        raise ValueError("--eval_horizons must contain at least one horizon")

    set_deterministic(args.seed)

    specs = load_models_json(args.models_json)
    if not specs:
        raise ValueError("No models provided")

    export_dir = str(args.export_dir)
    _ensure_dir(export_dir)
    cause_names = load_cause_names("labels.csv")

    # Determine top-K causes from the evaluation split only (model-agnostic),
    # aligned to time-dependent risk: occurrence within tau_max after context.
    first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path)
    cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [
        "bmi", "smoking", "alcohol"]
    dataset_for_top = HealthDataset(
        data_prefix=args.data_prefix, covariate_list=cov_list)
    subset_for_top = build_eval_subset(
        dataset_for_top,
        train_ratio=float(first_cfg.get("train_ratio", 0.7)),
        val_ratio=float(first_cfg.get("val_ratio", 0.15)),
        seed=int(first_cfg.get("random_seed", 42)),
        split=args.split,
    )
    loader_top = DataLoader(
        subset_for_top,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=health_collate_fn,
    )

    tau_max = float(max(args.eval_horizons))
    counts, n_total_eval = count_occurs_within_horizon(
        loader=loader_top,
        offset_years=args.offset_years,
        tau_years=tau_max,
        n_disease=dataset_for_top.n_disease,
        device=args.device,
    )

    focus_causes = pick_focus_causes(
        counts_within_tau=counts,
        n_disease=int(dataset_for_top.n_disease),
        death_cause_id=int(args.death_cause_id),
        k=int(args.focus_k),
    )
    top_cause_ids = np.asarray(focus_causes, dtype=int)

    # Export the chosen focus causes.
    focus_rows: List[Dict[str, Any]] = []
    for rank, cid in enumerate(focus_causes, start=1):
        row: Dict[str, Any] = {"cause": int(cid), "rank": int(rank)}
        if cid in cause_names:
            row["cause_name"] = cause_names[cid]
        focus_rows.append(row)
    # Only emit the cause_name column when at least one name was resolved.
    focus_fieldnames = ["cause", "rank"] + \
        (["cause_name"] if any("cause_name" in r for r in focus_rows) else [])
    write_simple_csv(os.path.join(export_dir, "focus_causes.csv"),
                     focus_fieldnames, focus_rows)

    # Metadata for focus causes (within tau_max).
    top_causes_meta: List[Dict[str, Any]] = []
    for cid in focus_causes:
        # Ids beyond the counts vector are reported with zero cases.
        n_case = int(counts[int(cid)]) if int(
            cid) < int(counts.shape[0]) else 0
        top_causes_meta.append(
            {
                "cause_id": int(cid),
                "tau_years": float(tau_max),
                "n_case_within_tau": n_case,
                "n_control_within_tau": int(n_total_eval - n_case),
                "n_total_eval": int(n_total_eval),
            }
        )

    summary_rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []

    # Experiment exports (accumulated across models)
    cap_points_rows: List[Dict[str, Any]] = []
    cap_curve_rows: List[Dict[str, Any]] = []
    conc_rows: List[Dict[str, Any]] = []
    conc_base_rows: List[Dict[str, Any]] = []

    # Track per-model integrity status for meta JSON.
    integrity_meta: Dict[str, Any] = {}

    # Values that do not depend on the model are computed once, up front.
    tag = _make_eval_tag(args.split, float(args.offset_years))
    topk_causes = [int(x) for x in args.cause_concentration_topk]
    k_pcts = [int(x) for x in args.capture_k_pcts]
    curve_max = int(args.capture_curve_max_pct)
    curve_grid = list(range(1, curve_max + 1)) if curve_max > 0 else []
    summary_fieldnames = [
        "model_name", "metric_name", "horizon", "mean", "median", "n_valid"]
    # (metric_name written to the results CSV, key read from top_causes_meta)
    topcause_metrics = (
        ("topcause_n_case_within_tau", "n_case_within_tau"),
        ("topcause_n_control_within_tau", "n_control_within_tau"),
        ("topcause_n_total_eval", "n_total_eval"),
    )

    # Evaluate each model
    for spec in specs:
        run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path))

        # Remember list offsets so we can write per-model slices to the model's run_dir.
        calib_start = len(calib_rows)

        cfg = load_train_config_for_checkpoint(spec.checkpoint_path)
        full_cov = _parse_bool(cfg.get("full_cov", False))

        # Identifiers for consistent exports; shared by every export row of this model.
        ident: Dict[str, Any] = {
            "model_id": str(spec.name),
            "model_type": str(
                cfg.get("model_type", getattr(spec, "model_type", ""))),
            "loss_type": str(
                cfg.get("loss_type", getattr(spec, "loss_type", ""))),
            "age_encoder": str(cfg.get("age_encoder", "")),
            "cov_type": "full" if full_cov else "partial",
        }

        cov_list = None if full_cov else ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(
            data_prefix=args.data_prefix, covariate_list=cov_list)
        subset = build_eval_subset(
            dataset,
            train_ratio=float(cfg.get("train_ratio", 0.7)),
            val_ratio=float(cfg.get("val_ratio", 0.15)),
            seed=int(cfg.get("random_seed", 42)),
            split=args.split,
        )
        loader = DataLoader(
            subset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=health_collate_fn,
        )

        # NOTE(security): torch.load unpickles the file; only point this script
        # at checkpoints from a trusted source.
        ckpt = torch.load(spec.checkpoint_path, map_location=args.device)

        backbone, head, loss_type, bin_edges, loss_params = instantiate_model_and_head(
            cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path)
        backbone.load_state_dict(ckpt["model_state_dict"], strict=True)
        head.load_state_dict(ckpt["head_state_dict"], strict=True)

        if loss_type == "lognormal_basis_binned_hazard_cif":
            # Prefer the log_sigma stored with the criterion state; otherwise
            # fall back to the configured initial bandwidth.
            crit_state = ckpt.get("criterion_state_dict", {})
            log_sigma = crit_state.get("log_sigma", None)
            if isinstance(log_sigma, torch.Tensor):
                log_sigma_t = log_sigma.to(device=args.device)
                sigma = torch.exp(log_sigma_t)
            else:
                sigma = torch.tensor(float(loss_params.get(
                    "bandwidth_init", 0.7)), device=args.device)
            bmin = float(loss_params.get("bandwidth_min", 1e-3))
            bmax = float(loss_params.get("bandwidth_max", 10.0))
            sigma = torch.clamp(sigma, min=bmin, max=bmax)
            loss_params["sigma"] = sigma

        (
            cif_full,
            survival,
            y_cause_within_tau,
            sex,
        ) = predict_cifs_for_model(
            backbone,
            head,
            loss_type,
            bin_edges,
            loss_params,
            loader,
            args.device,
            args.offset_years,
            args.eval_horizons,
            n_disease=int(dataset.n_disease),
        )

        # CIF integrity checks before metrics.
        integrity_ok, integrity_notes = check_cif_integrity(
            cif_full,
            args.eval_horizons,
            tol=float(args.integrity_tol),
            name=spec.name,
            strict=bool(args.integrity_strict),
            survival=survival,
        )
        integrity_meta[spec.name] = {
            "integrity_ok": bool(integrity_ok),
            "integrity_notes": integrity_notes,
        }

        # Per-disease metrics for ALL diseases (written into the model's run_dir).
        model_rows: List[Dict[str, Any]] = []
        evaluate_one_model(
            model_name=spec.name,
            cif_full=cif_full,
            y_cause_within_tau=y_cause_within_tau,
            eval_horizons=args.eval_horizons,
            out_rows=model_rows,
            calib_rows=calib_rows,
            calib_cause_ids=top_cause_ids.tolist(),
            metric_workers=int(args.metric_workers),
            progress=str(args.progress),
        )

        # Summary over diseases (mean/median per horizon).
        model_summary_rows = summarize_over_diseases(
            model_rows,
            model_name=spec.name,
            eval_horizons=args.eval_horizons,
        )
        summary_rows.extend(model_summary_rows)

        # ============================================================
        # Experiment: High-Risk Cause Concentration at fixed horizon
        # (cross-cause prioritization accuracy)
        # ============================================================
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float)
                y_tau_all = np.asarray(
                    y_cause_within_tau[:, :, h_i], dtype=float)
                if sex_mask is not None:
                    p_tau_all = p_tau_all[sex_mask]
                    y_tau_all = y_tau_all[sex_mask]
                for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes):
                    conc_rows.append(
                        {
                            **ident,
                            "horizon": float(tau),
                            "sex": sex_label,
                            **rr,
                        }
                    )

                if bool(args.cause_concentration_write_random_baseline):
                    for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes):
                        conc_base_rows.append(
                            {
                                **ident,
                                "horizon": float(tau),
                                "sex": sex_label,
                                **rr,
                            }
                        )

        # Convenience slices for user-facing experiments (focus causes only).
        cause_cif_focus = cif_full[:, top_cause_ids, :]
        y_within_focus = y_cause_within_tau[:, top_cause_ids, :]

        # ============================================================
        # Experiment 2: High-risk capture points (+ optional curve)
        # ============================================================
        for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
            for h_i, tau in enumerate(args.eval_horizons):
                for j, cause_id in enumerate(top_cause_ids.tolist()):
                    p = cause_cif_focus[:, j, h_i]
                    y = y_within_focus[:, j, h_i]
                    if sex_mask is not None:
                        p = p[sex_mask]
                        y = y[sex_mask]

                    for r in compute_capture_points(p, y, k_pcts):
                        cap_points_rows.append(
                            {
                                **ident,
                                "cause": int(cause_id),
                                "horizon": float(tau),
                                "sex": sex_label,
                                **r,
                            }
                        )
                    if curve_grid:
                        for r in compute_capture_points(p, y, curve_grid):
                            cap_curve_rows.append(
                                {
                                    **ident,
                                    "cause": int(cause_id),
                                    "horizon": float(tau),
                                    "sex": sex_label,
                                    **r,
                                }
                            )

        # Write top-cause counts into the per-model results CSV as metric rows.
        for tc in top_causes_meta:
            for metric_name, count_key in topcause_metrics:
                model_rows.append(
                    {
                        "model_name": spec.name,
                        "metric_name": metric_name,
                        "horizon": float(tc["tau_years"]),
                        "cause": int(tc["cause_id"]),
                        "value": int(tc[count_key]),
                        "ci_low": "",
                        "ci_high": "",
                    }
                )

        # Write per-model results into the model's run directory.
        model_calib_rows = calib_rows[calib_start:]
        model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv")
        model_summary_csv = os.path.join(run_dir, f"eval_summary_{tag}.csv")
        model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv")
        model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json")

        write_results_csv(model_out_csv, model_rows)
        write_simple_csv(
            model_summary_csv,
            summary_fieldnames,
            model_summary_rows,
        )
        write_calibration_bins_csv(model_calib_csv, model_calib_rows)

        model_meta = {
            "model_name": spec.name,
            "checkpoint_path": spec.checkpoint_path,
            "run_dir": run_dir,
            "split": args.split,
            "offset_years": args.offset_years,
            "eval_horizons": [float(x) for x in args.eval_horizons],
            "n_disease": int(dataset.n_disease),
            "top_cause_ids": top_cause_ids.tolist(),
            "top_causes": top_causes_meta,
            "integrity": {spec.name: integrity_meta.get(spec.name, {})},
            "paths": {
                "results_csv": model_out_csv,
                "summary_csv": model_summary_csv,
                "calibration_bins_csv": model_calib_csv,
            },
        }
        with open(model_meta_json, "w") as f:
            json.dump(model_meta, f, indent=2)

        print(f"Wrote per-model results to {model_out_csv}")

    # Write global summary (across diseases) across all models.
    write_simple_csv(
        args.out_csv,
        summary_fieldnames,
        summary_rows,
    )

    # Write calibration curve points to a separate CSV.
    out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "."
    calib_csv_path = os.path.join(out_dir, "calibration_bins.csv")
    write_calibration_bins_csv(calib_csv_path, calib_rows)

    # Write experiment exports
    ident_fieldnames = [
        "model_id",
        "model_type",
        "loss_type",
        "age_encoder",
        "cov_type",
    ]
    # The capture points and the dense capture curve share one schema.
    capture_fieldnames = ident_fieldnames + [
        "cause",
        "horizon",
        "sex",
        "k_pct",
        "n_targeted",
        "events_targeted",
        "events_total",
        "event_capture_rate",
        "precision_in_targeted",
    ]
    write_simple_csv(
        os.path.join(export_dir, "lift_capture_points.csv"),
        capture_fieldnames,
        cap_points_rows,
    )
    if cap_curve_rows:
        write_simple_csv(
            os.path.join(export_dir, "lift_capture_curve.csv"),
            capture_fieldnames,
            cap_curve_rows,
        )

    write_simple_csv(
        os.path.join(export_dir, "high_risk_cause_concentration.csv"),
        ident_fieldnames + [
            "horizon",
            "sex",
            "topk",
            "event_rate_mean",
            "event_rate_median",
            "recall_mean",
            "recall_median",
            "n_total",
            "n_valid_recall",
        ],
        conc_rows,
    )

    if conc_base_rows:
        write_simple_csv(
            os.path.join(
                export_dir, "high_risk_cause_concentration_random_baseline.csv"),
            ident_fieldnames + [
                "horizon",
                "sex",
                "topk",
                "baseline_event_rate_mean",
                "baseline_event_rate_p05",
                "baseline_event_rate_p95",
                "baseline_recall_mean",
                "baseline_recall_p05",
                "baseline_recall_p95",
                "n_total",
                "n_valid_recall",
                "k_total",
                "baseline_method",
            ],
            conc_base_rows,
        )

    # Manifest markdown (stable, user-facing)
    manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
    with open(manifest_path, "w", encoding="utf-8") as f:
        f.write(
            "# Evaluation Exports Manifest\n\n"
            "This folder contains user-facing CSV artifacts for multi-disease, cause-specific, time-dependent risk evaluation (CIF-based). "
            "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
            "## Files\n\n"
            "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
            "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
            "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
            "- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart vs K.\n"
            "- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n"
        )

    meta = {
        "split": args.split,
        "offset_years": args.offset_years,
        "eval_horizons": [float(x) for x in args.eval_horizons],
        "tau_max": float(tau_max),
        "n_disease": int(dataset_for_top.n_disease),
        "top_cause_ids": top_cause_ids.tolist(),
        "top_causes": top_causes_meta,
        "integrity": integrity_meta,
        "notes": {
            "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
            "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
            "secondary_metrics": "cause_auc (discrimination)",
            "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
            "exports_dir": export_dir,
            # Plain ints so json.dump also works if the ids are NumPy integers.
            "focus_causes": [int(c) for c in focus_causes],
        },
    }
    with open(args.out_meta_json, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Wrote {args.out_csv} with {len(summary_rows)} rows")
    print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows")
    print(f"Wrote {args.out_meta_json}")
    return 0
|
||
|
||
|
||
# Script entry point: run the evaluation and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
|