def set_deterministic(seed: int) -> None:
    """Seed every RNG used by this script and pin cuDNN to deterministic mode.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's
    default generator, and the generators of all CUDA devices. Afterwards
    cuDNN is switched to ``deterministic=True`` / ``benchmark=False``.

    Args:
        seed: Seed value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def build_eval_subset(
    dataset: torch.utils.data.Dataset,
    train_ratio: float,
    val_ratio: float,
    seed: int,
    split: str,
):
    """Return the requested partition of ``dataset`` from a seeded random split.

    The dataset is divided into train/val/test parts of sizes
    ``int(n * train_ratio)``, ``int(n * val_ratio)`` and the remainder, using
    ``random_split`` with a private generator seeded by ``seed`` (the global
    RNG state is not touched).

    Args:
        dataset: Any sized dataset accepted by ``random_split``.
        train_ratio: Fraction of samples assigned to the train part.
        val_ratio: Fraction of samples assigned to the validation part.
        seed: Seed for the split generator.
        split: One of ``"train"``, ``"val"``, ``"test"`` or ``"all"``.

    Returns:
        The selected subset, or ``dataset`` itself when ``split == "all"``.

    Raises:
        ValueError: If ``split`` is not a known name, or if the ratios yield a
            negative size for any part (e.g. ``train_ratio + val_ratio > 1``).
    """
    # Fail fast on a bad split name before doing any work.
    if split not in {"train", "val", "test", "all"}:
        raise ValueError("split must be one of: train, val, test, all")
    if split == "all":
        # The full dataset does not depend on the ratios; no split needed.
        return dataset

    n_total = len(dataset)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val

    # A negative part size would otherwise be handed to random_split as-is,
    # even though the three sizes still sum to n_total.
    if min(n_train, n_val, n_test) < 0:
        raise ValueError(
            "invalid split sizes "
            f"(train={n_train}, val={n_val}, test={n_test}) for "
            f"train_ratio={train_ratio}, val_ratio={val_ratio}, n_total={n_total}"
        )

    train_ds, val_ds, test_ds = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    return {"train": train_ds, "val": val_ds, "test": test_ds}[split]
+ - We pick the last valid token with time <= (followup_end_time - offset). + - We do NOT interpret followup_end_time as an event time. + + Returns: + keep_mask: (B,) bool, which samples have a valid context + t_ctx: (B,) long, index into sequence + t_ctx_time: (B,) float, time (days) at context + """ + # valid tokens are event != 0 (padding is 0) + valid = event_seq != 0 + lengths = valid.sum(dim=1) + last_idx = torch.clamp(lengths - 1, min=0) + + b = torch.arange(event_seq.size(0), device=event_seq.device) + followup_end_time = time_seq[b, last_idx] + t_cut = followup_end_time - (offset_years * DAYS_PER_YEAR) + + eligible = valid & (time_seq <= t_cut.unsqueeze(1)) + eligible_counts = eligible.sum(dim=1) + keep = eligible_counts > 0 + + t_ctx = torch.clamp(eligible_counts - 1, min=0).to(torch.long) + t_ctx_time = time_seq[b, t_ctx] + + return keep, t_ctx, t_ctx_time + + +def next_event_after_context( + event_seq: torch.Tensor, + time_seq: torch.Tensor, + t_ctx: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Return next disease event after context. + + Returns: + dt_years: (B,) float, time to next disease in years; +inf if none + cause: (B,) long, disease id in [0,K) for next event; -1 if none + """ + B, L = event_seq.shape + b = torch.arange(B, device=event_seq.device) + t0 = time_seq[b, t_ctx] + + # Allow same-day events while excluding the context token itself. + # We rely on time-sorted sequences and select the FIRST valid future event by index. 
def multi_hot_ever_after_context_anytime(
    event_seq: torch.Tensor,
    t_ctx: torch.Tensor,
    n_disease: int,
) -> torch.Tensor:
    """Multi-hot labels: does disease k appear at any position after the context?

    A token counts when its sequence index is strictly greater than the
    sample's context index and its value is >= 2; the disease id is
    ``token - 2``. Time stamps are not consulted, so same-day events are
    included as long as they sit after the context token.

    Args:
        event_seq: (B, L) integer token ids.
        t_ctx: (B,) context index per sample.
        n_disease: Number of label columns.

    Returns:
        (B, n_disease) bool tensor on ``event_seq.device``.
    """
    n_rows, seq_len = event_seq.shape
    dev = event_seq.device

    labels = torch.zeros((n_rows, n_disease), dtype=torch.bool, device=dev)

    positions = torch.arange(seq_len, device=dev)[None, :]  # (1, L), broadcasts over rows
    after_context = positions > t_ctx[:, None]
    # ">= 2" already excludes token 0, so no separate "!= 0" test is needed.
    is_disease = event_seq >= 2
    future_mask = after_context & is_disease

    if future_mask.any():
        row_idx = torch.where(future_mask)[0]
        # Boolean-mask indexing walks the tensor in the same row-major order
        # as torch.where, so tokens line up with row_idx.
        disease_ids = (event_seq[future_mask] - 2).to(torch.long)
        labels[row_idx, disease_ids] = True
    return labels
def cifs_from_exponential_logits(
    logits: torch.Tensor,
    taus: Sequence[float],
    eps: float = 1e-6,
    return_survival: bool = False,
) -> torch.Tensor:
    """Convert exponential cause-specific logits into CIFs at the given horizons.

    Per-cause hazards are ``softplus(logits) + eps``. With total hazard
    ``Lambda = sum_k hazard_k``, the cumulative incidence of cause k at
    horizon tau is ``(hazard_k / Lambda) * (1 - exp(-Lambda * tau))`` and the
    event-free survival is ``exp(-Lambda * tau)``.

    Args:
        logits: (B, K) tensor of cause-specific logits.
        taus: H horizons, in the same time unit as the hazards.
        eps: Added to every hazard and used as the lower clamp of the
            denominator.
        return_survival: If True, also return the survival tensor.

    Returns:
        ``cif`` of shape (B, K, H); or the tuple ``(cif, survival)`` with
        ``survival`` of shape (B, H) when ``return_survival`` is True.
    """
    hazards = F.softplus(logits) + eps  # (B,K)
    total = hazards.sum(dim=1, keepdim=True)  # (B,1)

    taus_t = torch.tensor(
        list(taus), device=logits.device, dtype=hazards.dtype
    ).view(1, 1, -1)  # (1,1,H)
    total_h = total.unsqueeze(-1)  # (B,1,1)

    # exp(-Lambda * tau), computed once and shared by the CIF and survival.
    surv = torch.exp(-total_h * taus_t)  # (B,1,H)

    frac = hazards / torch.clamp(total, min=eps)  # (B,K)
    cif = frac.unsqueeze(-1) * (1.0 - surv)  # (B,K,H)
    # If total==0 (only reachable with eps == 0), define the CIF as 0.
    cif = torch.where(total_h > 0, cif, torch.zeros_like(cif))

    if not return_survival:
        return cif

    survival = surv.squeeze(1)  # (B,H)
    # Keep `total` as (B,1) so it broadcasts across the H axis. A (B,) mask
    # would be aligned with H and fail unless B == H or either is 1.
    survival = torch.where(total > 0, survival, torch.ones_like(survival))
    return cif, survival
+ + logits: (B, K+1, n_bins+1) + bin_edges: len=n_bins+1 (including 0 and inf) + taus: subset of finite bin edges (recommended) + + returns: (B, K, H) or (cif, survival) if return_survival + """ + if logits.ndim != 3: + raise ValueError("Expected logits shape (B, K+1, n_bins+1)") + + B, K_plus_1, n_bins_plus_1 = logits.shape + K = K_plus_1 - 1 + + edges = [float(x) for x in bin_edges] + # drop the 0 edge; bins correspond to intervals ending at edges[1:], excluding +inf + finite_edges = [e for e in edges[1:] if math.isfinite(e)] + n_bins = len(finite_edges) + + if n_bins_plus_1 != len(edges): + raise ValueError("logits last dim must match len(bin_edges)") + + probs = torch.softmax(logits, dim=1) # (B, K+1, n_bins+1) + + # use bins 1..n_bins (ignore bin 0, ignore +inf bin edge slot) + hazards = probs[:, :K, 1: 1 + n_bins] # (B,K,n_bins) + p_comp = probs[:, K, 1: 1 + n_bins] # (B,n_bins) + + # survival before each bin: S_prev[0]=1, S_prev[u]=prod_{v< u} p_comp[v] + ones = torch.ones((B, 1), device=logits.device, dtype=probs.dtype) + cum = torch.cumprod(p_comp, dim=1) + s_prev = torch.cat([ones, cum[:, :-1]], dim=1) # (B, n_bins) + + cif_bins = torch.cumsum(s_prev.unsqueeze( + 1) * hazards, dim=2) # (B,K,n_bins) + + # Robust mapping from tau -> edge index (floating-point safe). + # taus are expected to align with bin edges, but may differ slightly due to parsing/serialization. 
+ finite_edges_arr = np.asarray(finite_edges, dtype=float) + tau_to_idx: List[int] = [] + for tau in taus: + tau_f = float(tau) + if not math.isfinite(tau_f): + raise ValueError("taus must be finite for discrete-time CIF") + diffs = np.abs(finite_edges_arr - tau_f) + j = int(np.argmin(diffs)) + if diffs[j] > 1e-6: + raise ValueError( + f"tau={tau_f} not close to any finite bin edge (min |edge-tau|={diffs[j]})" + ) + tau_to_idx.append(j) + + idx = torch.tensor(tau_to_idx, device=logits.device, dtype=torch.long) + cif = cif_bins.index_select(dim=2, index=idx) # (B,K,H) + + if not return_survival: + return cif + + # Survival at each horizon = prod_{u <= idx[h]} p_comp[u] + survival_bins = cum # (B,n_bins), cum[u] = prod_{v<=u} p_comp[v] + survival = survival_bins.index_select(dim=1, index=idx) # (B,H) + return cif, survival + + +# ============================================================ +# CIF integrity checks +# ============================================================ + +def check_cif_integrity( + cause_cif: np.ndarray, + horizons: Sequence[float], + *, + tol: float = 1e-6, + name: str = "", + strict: bool = False, + survival: Optional[np.ndarray] = None, +) -> Tuple[bool, List[str]]: + """Run basic sanity checks on CIF arrays. 
+ + Args: + cause_cif: (N, K, H) + horizons: length H + tol: tolerance for inequalities + name: model name for messages + strict: if True, raise ValueError on first failure + survival: optional (N, H) survival values at the same horizons + + Returns: + (integrity_ok, notes) + """ + notes: List[str] = [] + model_tag = f"[{name}] " if name else "" + + def _fail(msg: str) -> None: + full = model_tag + msg + if strict: + raise ValueError(full) + print("WARNING:", full) + notes.append(msg) + + cif = np.asarray(cause_cif) + if cif.ndim != 3: + _fail(f"integrity: expected cause_cif ndim=3, got {cif.ndim}") + return False, notes + N, K, H = cif.shape + if H != len(horizons): + _fail( + f"integrity: horizon length mismatch (H={H}, len(horizons)={len(horizons)})") + + # (5) Finite + if not np.isfinite(cif).all(): + _fail("integrity: non-finite values (NaN/Inf) in cause_cif") + + # (1) Range + cmin = float(np.nanmin(cif)) + cmax = float(np.nanmax(cif)) + if cmin < -tol: + _fail(f"integrity: range min={cmin} < -tol={-tol}") + if cmax > 1.0 + tol: + _fail(f"integrity: range max={cmax} > 1+tol={1.0+tol}") + + # (2) Monotonicity in horizons (per n,k) + diffs = np.diff(cif, axis=2) + if diffs.size > 0: + if np.nanmin(diffs) < -tol: + _fail("integrity: monotonicity violated (found negative diff along horizons)") + + # (3) Probability mass: sum_k CIF <= 1 + mass = np.sum(cif, axis=1) # (N,H) + mass_max = float(np.nanmax(mass)) + if mass_max > 1.0 + tol: + _fail(f"integrity: probability mass exceeds 1 (max sum_k={mass_max})") + + # (4) Conservation with survival, if provided + if survival is None: + warn = "integrity: survival not provided; skipping conservation check" + if strict: + # still skip (requested behavior), but keep message for context + notes.append(warn) + else: + print("WARNING:", model_tag + warn) + notes.append(warn) + else: + s = np.asarray(survival, dtype=float) + if s.shape != (N, H): + _fail( + f"integrity: survival shape mismatch (got {s.shape}, expected {(N, 
def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Return 1-based midranks of ``x`` (tied values share their average rank).

    Args:
        x: 1D array-like of scores.

    Returns:
        Float array of the same length; element ``i`` is the rank of ``x[i]``
        in ascending order, with ties replaced by the mean of the ranks they
        span.

    Notes:
        NaN never compares equal to itself, so each NaN forms its own group
        (``np.argsort`` places them last). The previous scan-based loop never
        advanced past a NaN and did not terminate.
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    if n == 0:
        return np.empty(0, dtype=float)

    order = np.argsort(x)
    z = x[order]

    # True at the first element of every run of equal sorted values.
    starts_group = np.empty(n, dtype=bool)
    starts_group[0] = True
    starts_group[1:] = z[1:] != z[:-1]

    starts = np.flatnonzero(starts_group)  # 0-based start of each tie group
    ends = np.append(starts[1:], n)  # exclusive end of each tie group
    # Group covers 1-based ranks start+1 .. end; their mean is the midrank.
    group_midrank = 0.5 * (starts + ends - 1) + 1.0

    group_of_sorted = np.cumsum(starts_group) - 1
    out = np.empty(n, dtype=float)
    out[order] = group_midrank[group_of_sorted]
    return out
def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Compute the AUC with a normal-approximation CI from the DeLong variance.

    Args:
        ground_truth: Labels, forwarded to ``calc_auc_variance``.
        predictions: Scores, forwarded to ``calc_auc_variance``.
        alpha: Two-sided confidence level (0.95 gives a 95% interval).

    Returns:
        ``(auc, ci_low, ci_high)``. The bounds are clipped to [0, 1]. When
        the variance is NaN/Inf or not strictly positive, a warning is
        printed and both bounds are NaN.
    """
    auc, variance = calc_auc_variance(ground_truth, predictions)

    variance_usable = bool(np.isfinite(variance)) and variance > 0
    if variance_usable:
        std_err = math.sqrt(variance)
        tail_mass = (1.0 - float(alpha)) / 2.0
        z_crit = statistics.NormalDist().inv_cdf(1.0 - tail_mass)
        margin = z_crit * std_err
        lower = max(0.0, auc - margin)
        upper = min(1.0, auc + margin)
        return float(auc), float(lower), float(upper)

    print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
    return float(auc), float("nan"), float("nan")
def calibration_deciles(p: np.ndarray, y: np.ndarray, n_bins: int = 10) -> Dict[str, Any]:
    """Summarise calibration over quantile bins of the predicted probabilities.

    Bin edges are the ``n_bins``-quantiles of ``p`` with the outermost edges
    widened to +/-inf; a sample falls in bin ``i`` when
    ``edges[i] < p <= edges[i + 1]``. Bins that receive no samples (e.g.
    because of duplicated edges) are omitted.

    Args:
        p: 1D predicted probabilities.
        y: Observed outcomes, same shape as ``p``.
        n_bins: Number of quantile bins (must be >= 1).

    Returns:
        Dict with
          - ``"bins"``: list of ``{"bin", "p_mean", "y_mean", "n"}`` per
            populated bin,
          - ``"ece"``: sum over bins of (bin fraction) * |p_mean - y_mean|,
          - ``"ici"``: unweighted mean of |p_mean - y_mean| over populated
            bins.
        ``ece``/``ici`` are NaN when there is nothing to summarise (empty
        input, or no sample falls in any bin such as all-NaN predictions),
        so that unusable input is not reported as perfect calibration.

    Raises:
        ValueError: If ``n_bins < 1`` or ``p`` and ``y`` differ in shape.
    """
    if int(n_bins) < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    if p.shape != y.shape:
        raise ValueError(
            f"p and y must have the same shape (got {p.shape} and {y.shape})")

    if p.size == 0:
        return {"bins": [], "ece": float("nan"), "ici": float("nan")}

    edges = np.quantile(p, np.linspace(0.0, 1.0, int(n_bins) + 1))
    # Open up the outermost edges so the minimum and maximum are always binned.
    edges[0] = -np.inf
    edges[-1] = np.inf

    bins: List[Dict[str, Any]] = []
    ece = 0.0
    gap_sum = 0.0

    for i in range(int(n_bins)):
        mask = (p > edges[i]) & (p <= edges[i + 1])
        if not np.any(mask):
            continue
        p_mean = float(np.mean(p[mask]))
        y_mean = float(np.mean(y[mask]))
        gap = abs(p_mean - y_mean)
        bins.append({"bin": i, "p_mean": p_mean,
                     "y_mean": y_mean, "n": int(mask.sum())})
        ece += float(np.mean(mask)) * gap
        gap_sum += gap

    if not bins:
        # Nothing was binned (e.g. every prediction is NaN): the metrics are
        # undefined rather than zero.
        return {"bins": [], "ece": float("nan"), "ici": float("nan")}

    ici = gap_sum / len(bins)
    return {"bins": bins, "ece": float(ece), "ici": float(ici)}
def instantiate_model_and_head(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: str,
    checkpoint_path: str = "",
) -> Tuple[torch.nn.Module, torch.nn.Module, str, Sequence[float]]:
    """Build the backbone and prediction head described by a training config.

    The head's output dimensions follow from ``cfg["loss_type"]``:
    ``[n_disease]`` for ``"exponential"`` and
    ``[n_disease + 1, len(bin_edges)]`` for ``"discrete_time_cif"``.
    The backbone class follows from ``cfg["model_type"]``
    (``"delphi_fork"`` or ``"sap_delphi"``). Both modules are moved to
    ``device``. No weights are loaded here; ``checkpoint_path`` is only used
    in the warning printed when a SapDelphi embedding path is missing.

    Args:
        cfg: Training configuration dictionary.
        dataset: Supplies ``n_disease``, ``n_cont``, ``n_cate`` and
            ``cate_dims``.
        device: Target device for both modules.
        checkpoint_path: Optional checkpoint location, used for diagnostics.

    Returns:
        ``(backbone, head, loss_type, bin_edges)`` where ``bin_edges`` is
        ``cfg["bin_edges"]`` or ``DEFAULT_BIN_EDGES`` when absent.

    Raises:
        ValueError: For an unsupported ``loss_type`` or ``model_type``.
    """
    model_type = str(cfg["model_type"])
    loss_type = str(cfg["loss_type"])

    # --- head output dimensions (depend only on the loss) ---
    if loss_type not in ("exponential", "discrete_time_cif"):
        raise ValueError(f"Unsupported loss_type for evaluation: {loss_type}")
    bin_edges = cfg.get("bin_edges", DEFAULT_BIN_EDGES)
    if loss_type == "exponential":
        out_dims = [dataset.n_disease]
    else:
        out_dims = [dataset.n_disease + 1, len(bin_edges)]

    def _shared_backbone_kwargs() -> Dict[str, Any]:
        # Constructor arguments common to both backbone classes. Built lazily
        # so config keys are only read once a supported model_type is known.
        return {
            "n_disease": dataset.n_disease,
            "n_tech_tokens": 2,
            "n_embd": int(cfg["n_embd"]),
            "n_head": int(cfg["n_head"]),
            "n_layer": int(cfg["n_layer"]),
            "pdrop": float(cfg.get("pdrop", 0.0)),
            "age_encoder_type": str(cfg.get("age_encoder", "sinusoidal")),
            "n_cont": dataset.n_cont,
            "n_cate": dataset.n_cate,
            "cate_dims": dataset.cate_dims,
        }

    # --- backbone ---
    if model_type == "delphi_fork":
        backbone = DelphiFork(**_shared_backbone_kwargs()).to(device)
    elif model_type == "sap_delphi":
        # Config key compatibility: prefer pretrained_emb_path, then fall
        # back to the pretrained_emd_path spelling.
        unset = {"", None}
        emb_path = cfg.get("pretrained_emb_path", None)
        if emb_path in unset:
            emb_path = cfg.get("pretrained_emd_path", None)
        if emb_path in unset:
            run_dir = ""
            if checkpoint_path:
                run_dir = os.path.dirname(os.path.abspath(checkpoint_path))
            print(
                f"WARNING: SapDelphi pretrained embedding path missing in config "
                f"(expected 'pretrained_emb_path' or 'pretrained_emd_path'). "
                f"checkpoint={checkpoint_path} run_dir={run_dir}"
            )
        backbone = SapDelphi(
            **_shared_backbone_kwargs(),
            pretrained_weights_path=emb_path,
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    head = SimpleHead(n_embd=int(cfg["n_embd"]), out_dims=out_dims).to(device)
    return backbone, head, loss_type, bin_edges
survival: (N, H) + sex: (N,) + y_allcause_tau: (N,H) + y_cause_ever_anytime: (N, topK) + y_cause_within_tau: (N, topK, H) + y_cause_within_tau_max: (N, topK) + + NOTE: + - y_cause_ever_anytime is Delphi2M-compatible case/control label. + - y_cause_within_tau_* corresponds to within-horizon labels (kept for legacy/secondary AUC). + """ + backbone.eval() + head.eval() + + # We will accumulate in CPU lists, then concat. + allcause_list: List[np.ndarray] = [] + cause_cif_list: List[np.ndarray] = [] + cif_full_list: List[np.ndarray] = [] + survival_list: List[np.ndarray] = [] + sex_list: List[np.ndarray] = [] + y_all_list: List[np.ndarray] = [] + y_cause_ever_any_list: List[np.ndarray] = [] + y_cause_within_list: List[np.ndarray] = [] + y_cause_within_tau_max_list: List[np.ndarray] = [] + + tau_max = float(max(eval_horizons)) + top_cause_ids_t = torch.tensor( + top_cause_ids, dtype=torch.long, device=device) + + # Efficiency: pre-create horizons tensor once per model (on device) and vectorize comparisons. 
+ eval_horizons_t = torch.tensor( + list(eval_horizons), device=device, dtype=torch.float32).view(1, -1) + + for batch in loader: + event_seq, time_seq, cont_feats, cate_feats, sexes = batch + event_seq = event_seq.to(device) + time_seq = time_seq.to(device) + cont_feats = cont_feats.to(device) + cate_feats = cate_feats.to(device) + sexes = sexes.to(device) + + keep, t_ctx, _ = select_context_indices( + event_seq, time_seq, offset_years) + if not keep.any(): + continue + + # filter batch + event_seq = event_seq[keep] + time_seq = time_seq[keep] + cont_feats = cont_feats[keep] + cate_feats = cate_feats[keep] + sexes_k = sexes[keep] + t_ctx = t_ctx[keep] + + h = backbone(event_seq, time_seq, sexes_k, + cont_feats, cate_feats) # (B,L,D) + b = torch.arange(h.size(0), device=device) + c = h[b, t_ctx] # (B,D) + logits = head(c) + + if loss_type == "exponential": + cif_full, survival = cifs_from_exponential_logits( + logits, eval_horizons, return_survival=True) # (B,K,H), (B,H) + elif loss_type == "discrete_time_cif": + cif_full, survival = cifs_from_discrete_time_logits( + # (B,K,H), (B,H) + logits, bin_edges, eval_horizons, return_survival=True) + else: + raise ValueError(f"Unsupported loss_type: {loss_type}") + + allcause = cif_full.sum(dim=1) # (B,H) + cause_cif = cif_full.index_select( + dim=1, index=top_cause_ids_t) # (B,topK,H) + + # outcomes + dt_next, _cause_next = next_event_after_context( + event_seq, time_seq, t_ctx) + y_all = (dt_next.view(-1, 1) <= eval_horizons_t).to(torch.float32) + + # Delphi2M-compatible ever label (does not depend on horizon) + y_ever_any = multi_hot_ever_after_context_anytime( + event_seq=event_seq, + t_ctx=t_ctx, + n_disease=int(cif_full.size(1)), + ) + y_ever_any_top = y_ever_any.index_select( + dim=1, index=top_cause_ids_t).to(torch.float32) + + # Within-horizon labels for cause-specific CIF quality + legacy AUC + n_disease = int(cif_full.size(1)) + y_within_top = torch.stack( + [ + multi_hot_selected_causes_within_horizon( + 
def pick_top_causes(y_ever: np.ndarray, top_k: int) -> np.ndarray:
    """Return the indices of the most frequent causes, most frequent first.

    Args:
        y_ever: (N, K) array of per-sample indicators; column sums are used
            as cause counts.
        top_k: Maximum number of causes to return. Values <= 0 give an empty
            result.

    Returns:
        1D integer array of at most ``top_k`` column indices whose count is
        > 0, ordered by decreasing count. Causes with equal counts keep
        ascending index order, so the selection is reproducible across runs
        and platforms.
    """
    counts = np.asarray(y_ever).sum(axis=0)
    # A stable sort makes tie-breaking deterministic (lower cause id first).
    order = np.argsort(-counts, kind="stable")
    order = order[counts[order] > 0]
    # Clamp so a negative top_k cannot turn into "all but the last |top_k|".
    return order[: max(int(top_k), 0)]
allcause_risk: np.ndarray, + cause_cif: np.ndarray, + sex: np.ndarray, + y_allcause: np.ndarray, + y_cause_ever_anytime: np.ndarray, + y_cause_within_tau: np.ndarray, + y_cause_within_tau_max: np.ndarray, + eval_horizons: Sequence[float], + top_cause_ids: np.ndarray, + out_rows: List[Dict[str, Any]], + calib_rows: List[Dict[str, Any]], + auc_ci_method: str, + bootstrap_n: int, + n_calib_bins: int = 10, +) -> None: + H = len(eval_horizons) + + # Task B (all-cause): Brier + AUC + calibration per horizon + for h_i, tau in enumerate(eval_horizons): + p = allcause_risk[:, h_i] + y = y_allcause[:, h_i] + + out_rows.append( + { + "model_name": model_name, + "metric_name": "allcause_brier", + "horizon": float(tau), + "cause": "", + "value": brier_score(p, y), + "ci_low": "", + "ci_high": "", + } + ) + + if auc_ci_method == "none": + auc, lo, hi = float("nan"), float("nan"), float("nan") + auc = float("nan") + elif auc_ci_method == "bootstrap": + auc, lo, hi = bootstrap_auc_ci( + p, y, n_bootstrap=bootstrap_n, alpha=0.95) + else: + auc, lo, hi = delong_ci(y, p, alpha=0.95) + out_rows.append( + { + "model_name": model_name, + "metric_name": "allcause_auc", + "horizon": float(tau), + "cause": "", + "value": auc, + "ci_low": lo, + "ci_high": hi, + } + ) + + cal = calibration_deciles(p, y, n_bins=n_calib_bins) + out_rows.append( + { + "model_name": model_name, + "metric_name": "allcause_ece", + "horizon": float(tau), + "cause": "", + "value": cal["ece"], + "ci_low": "", + "ci_high": "", + } + ) + out_rows.append( + { + "model_name": model_name, + "metric_name": "allcause_ici", + "horizon": float(tau), + "cause": "", + "value": cal["ici"], + "ci_low": "", + "ci_high": "", + } + ) + + # Write calibration bins into a separate CSV (always for all-cause). 
+ for binfo in cal.get("bins", []): + calib_rows.append( + { + "model_name": model_name, + "task": "all_cause", + "horizon": float(tau), + "cause_id": -1, + "bin_index": int(binfo["bin"]), + "p_mean": float(binfo["p_mean"]), + "y_mean": float(binfo["y_mean"]), + "n_in_bin": int(binfo["n"]), + } + ) + + # Stratification by sex + for s_val in [0, 1]: + m = sex == s_val + if np.sum(m) < 10: + continue + p_s = p[m] + y_s = y[m] + if auc_ci_method == "none": + auc_s, lo_s, hi_s = float("nan"), float("nan"), float("nan") + elif auc_ci_method == "bootstrap": + auc_s, lo_s, hi_s = bootstrap_auc_ci( + p_s, y_s, n_bootstrap=bootstrap_n, alpha=0.95) + else: + auc_s, lo_s, hi_s = delong_ci(y_s, p_s, alpha=0.95) + out_rows.append( + { + "model_name": model_name, + "metric_name": f"allcause_auc_sex{s_val}", + "horizon": float(tau), + "cause": "", + "value": auc_s, + "ci_low": lo_s, + "ci_high": hi_s, + } + ) + + # Task A (Delphi2M-compatible discrimination): per-cause AUC with EVER labels + # case/control is defined by whether the disease appears ANYTIME after context. + tau_max = float(max(eval_horizons)) + p_tau_max = cause_cif[:, :, -1] # (N, topK) + + for j, cause_id in enumerate(top_cause_ids.tolist()): + yk = y_cause_ever_anytime[:, j] + pk = p_tau_max[:, j] + if auc_ci_method == "none": + auc, lo, hi = float("nan"), float("nan"), float("nan") + elif auc_ci_method == "bootstrap": + auc, lo, hi = bootstrap_auc_ci( + pk, yk, n_bootstrap=bootstrap_n, alpha=0.95) + else: + auc, lo, hi = delong_ci(yk, pk, alpha=0.95) + out_rows.append( + { + "model_name": model_name, + "metric_name": "cause_auc_ever", + "horizon": tau_max, + "cause": int(cause_id), + "value": auc, + "ci_low": lo, + "ci_high": hi, + } + ) + + # Keep the existing tau-window AUC as a separate metric (do not remove). 
+ for j, cause_id in enumerate(top_cause_ids.tolist()): + yk = y_cause_within_tau_max[:, j] + pk = p_tau_max[:, j] + if auc_ci_method == "none": + auc, lo, hi = float("nan"), float("nan"), float("nan") + elif auc_ci_method == "bootstrap": + auc, lo, hi = bootstrap_auc_ci( + pk, yk, n_bootstrap=bootstrap_n, alpha=0.95) + else: + auc, lo, hi = delong_ci(yk, pk, alpha=0.95) + out_rows.append( + { + "model_name": model_name, + "metric_name": "cause_auc", + "horizon": tau_max, + "cause": int(cause_id), + "value": auc, + "ci_low": lo, + "ci_high": hi, + } + ) + + # Task B additions: cause-specific Brier + calibration curves at tau=3.84 and 10.0 + tau_targets = [3.84, 10.0] + horizon_to_idx = {float(t): i for i, t in enumerate( + [float(x) for x in eval_horizons])} + for tau in tau_targets: + if float(tau) not in horizon_to_idx: + continue + h_idx = horizon_to_idx[float(tau)] + p_tau = cause_cif[:, :, h_idx] # (N, topK) + y_tau = y_cause_within_tau[:, :, h_idx] # (N, topK) + + for j, cause_id in enumerate(top_cause_ids.tolist()): + p = p_tau[:, j] + y = y_tau[:, j] + + out_rows.append( + { + "model_name": model_name, + "metric_name": "cause_brier", + "horizon": float(tau), + "cause": int(cause_id), + "value": brier_score(p, y), + "ci_low": "", + "ci_high": "", + } + ) + + cal = calibration_deciles(p, y) + out_rows.append( + { + "model_name": model_name, + "metric_name": "cause_ece", + "horizon": float(tau), + "cause": int(cause_id), + "value": cal["ece"], + "ci_low": "", + "ci_high": "", + } + ) + out_rows.append( + { + "model_name": model_name, + "metric_name": "cause_ici", + "horizon": float(tau), + "cause": int(cause_id), + "value": cal["ici"], + "ci_low": "", + "ci_high": "", + } + ) + + # Write cause calibration bins into separate CSV only for tau targets. 
+ for binfo in cal.get("bins", []): + calib_rows.append( + { + "model_name": model_name, + "task": "cause_k", + "horizon": float(tau), + "cause_id": int(cause_id), + "bin_index": int(binfo["bin"]), + "p_mean": float(binfo["p_mean"]), + "y_mean": float(binfo["y_mean"]), + "n_in_bin": int(binfo["n"]), + } + ) + + +def write_calibration_bins_csv(path: str, rows: List[Dict[str, Any]]) -> None: + fieldnames = [ + "model_name", + "task", + "horizon", + "cause_id", + "bin_index", + "p_mean", + "y_mean", + "n_in_bin", + ] + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows: + w.writerow(r) + + +def write_results_csv(path: str, rows: List[Dict[str, Any]]) -> None: + fieldnames = [ + "model_name", + "metric_name", + "horizon", + "cause", + "value", + "ci_low", + "ci_high", + ] + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows: + w.writerow(r) + + +def _make_eval_tag(split: str, offset_years: float) -> str: + """Short tag for filenames written into run directories.""" + off = f"{float(offset_years):.4f}".rstrip("0").rstrip(".") + return f"{split}_offset{off}y" + + +def main() -> int: + ap = argparse.ArgumentParser( + description="Unified downstream evaluation via CIFs") + ap.add_argument("--models_json", type=str, required=True, + help="Path to models list JSON") + ap.add_argument("--data_prefix", type=str, + default="ukb", help="Dataset prefix") + ap.add_argument("--split", type=str, default="test", + choices=["train", "val", "test", "all"], help="Which split to evaluate") + ap.add_argument("--offset_years", type=float, default=0.5, + help="Anti-leakage offset (years)") + ap.add_argument("--eval_horizons", type=float, + nargs="*", default=DEFAULT_EVAL_HORIZONS) + ap.add_argument("--top_k_causes", type=int, default=50) + ap.add_argument("--batch_size", type=int, default=128) + ap.add_argument("--num_workers", type=int, default=0) + 
ap.add_argument("--seed", type=int, default=123) + ap.add_argument("--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu") + ap.add_argument("--out_csv", type=str, default="eval_results.csv") + ap.add_argument("--out_meta_json", type=str, default="eval_meta.json") + + # Integrity checks + ap.add_argument("--integrity_strict", action="store_true", default=False) + ap.add_argument("--integrity_tol", type=float, default=1e-6) + + # AUC CI methods + ap.add_argument( + "--auc_ci_method", + type=str, + default="delong", + choices=["delong", "bootstrap", "none"], + ) + ap.add_argument("--bootstrap_n", type=int, default=2000) + args = ap.parse_args() + + set_deterministic(args.seed) + + specs = load_models_json(args.models_json) + if not specs: + raise ValueError("No models provided") + + # Determine top-K causes from the evaluation split only (model-agnostic). + first_cfg = load_train_config_for_checkpoint(specs[0].checkpoint_path) + cov_list = None if _parse_bool(first_cfg.get("full_cov", False)) else [ + "bmi", "smoking", "alcohol"] + dataset_for_top = HealthDataset( + data_prefix=args.data_prefix, covariate_list=cov_list) + subset_for_top = build_eval_subset( + dataset_for_top, + train_ratio=float(first_cfg.get("train_ratio", 0.7)), + val_ratio=float(first_cfg.get("val_ratio", 0.15)), + seed=int(first_cfg.get("random_seed", 42)), + split=args.split, + ) + loader_top = DataLoader( + subset_for_top, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + collate_fn=health_collate_fn, + ) + + counts, n_total_eval = count_ever_after_context_anytime( + loader=loader_top, + offset_years=args.offset_years, + n_disease=dataset_for_top.n_disease, + device=args.device, + ) + order = np.argsort(-counts) + order = order[counts[order] > 0] + top_cause_ids = order[: args.top_k_causes] + + # Record top-cause counts under Delphi2M-compatible EVER label. 
+ top_causes_meta: List[Dict[str, Any]] = [] + for k in top_cause_ids.tolist(): + n_case = int(counts[int(k)]) + top_causes_meta.append( + { + "cause_id": int(k), + "n_case_ever": n_case, + "n_control_ever": int(n_total_eval - n_case), + "n_total_eval": int(n_total_eval), + } + ) + + rows: List[Dict[str, Any]] = [] + calib_rows: List[Dict[str, Any]] = [] + + # Track per-model integrity status for meta JSON. + integrity_meta: Dict[str, Any] = {} + + # Evaluate each model + for spec in specs: + run_dir = os.path.dirname(os.path.abspath(spec.checkpoint_path)) + tag = _make_eval_tag(args.split, float(args.offset_years)) + + # Remember list offsets so we can write per-model slices to the model's run_dir. + rows_start = len(rows) + calib_start = len(calib_rows) + + cfg = load_train_config_for_checkpoint(spec.checkpoint_path) + + cov_list = None if _parse_bool(cfg.get("full_cov", False)) else [ + "bmi", "smoking", "alcohol"] + dataset = HealthDataset( + data_prefix=args.data_prefix, covariate_list=cov_list) + subset = build_eval_subset( + dataset, + train_ratio=float(cfg.get("train_ratio", 0.7)), + val_ratio=float(cfg.get("val_ratio", 0.15)), + seed=int(cfg.get("random_seed", 42)), + split=args.split, + ) + loader = DataLoader( + subset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + collate_fn=health_collate_fn, + ) + + backbone, head, loss_type, bin_edges = instantiate_model_and_head( + cfg, dataset, args.device, checkpoint_path=spec.checkpoint_path) + ckpt = torch.load(spec.checkpoint_path, map_location=args.device) + backbone.load_state_dict(ckpt["model_state_dict"], strict=True) + head.load_state_dict(ckpt["head_state_dict"], strict=True) + + ( + allcause_risk, + cause_cif, + cif_full, + survival, + sex, + y_allcause, + y_cause_ever_anytime, + y_cause_within_tau, + y_cause_within_tau_max, + ) = predict_cifs_for_model( + backbone, + head, + loss_type, + bin_edges, + loader, + args.device, + args.offset_years, + args.eval_horizons, + 
top_cause_ids, + ) + + # CIF integrity checks before metrics. + integrity_ok, integrity_notes = check_cif_integrity( + cif_full, + args.eval_horizons, + tol=float(args.integrity_tol), + name=spec.name, + strict=bool(args.integrity_strict), + survival=survival, + ) + integrity_meta[spec.name] = { + "integrity_ok": bool(integrity_ok), + "integrity_notes": integrity_notes, + } + + evaluate_one_model( + model_name=spec.name, + allcause_risk=allcause_risk, + cause_cif=cause_cif, + sex=sex, + y_allcause=y_allcause, + y_cause_ever_anytime=y_cause_ever_anytime, + y_cause_within_tau=y_cause_within_tau, + y_cause_within_tau_max=y_cause_within_tau_max, + eval_horizons=args.eval_horizons, + top_cause_ids=top_cause_ids, + out_rows=rows, + calib_rows=calib_rows, + auc_ci_method=str(args.auc_ci_method), + bootstrap_n=int(args.bootstrap_n), + ) + + # Optionally write top-cause counts into the main results CSV as metric rows. + for tc in top_causes_meta: + rows.append( + { + "model_name": spec.name, + "metric_name": "topcause_n_case_ever", + "horizon": "", + "cause": int(tc["cause_id"]), + "value": int(tc["n_case_ever"]), + "ci_low": "", + "ci_high": "", + } + ) + rows.append( + { + "model_name": spec.name, + "metric_name": "topcause_n_control_ever", + "horizon": "", + "cause": int(tc["cause_id"]), + "value": int(tc["n_control_ever"]), + "ci_low": "", + "ci_high": "", + } + ) + rows.append( + { + "model_name": spec.name, + "metric_name": "topcause_n_total_eval", + "horizon": "", + "cause": int(tc["cause_id"]), + "value": int(tc["n_total_eval"]), + "ci_low": "", + "ci_high": "", + } + ) + + # Write per-model results into the model's run directory. 
+ model_rows = rows[rows_start:] + model_calib_rows = calib_rows[calib_start:] + model_out_csv = os.path.join(run_dir, f"eval_results_{tag}.csv") + model_calib_csv = os.path.join(run_dir, f"calibration_bins_{tag}.csv") + model_meta_json = os.path.join(run_dir, f"eval_meta_{tag}.json") + + write_results_csv(model_out_csv, model_rows) + write_calibration_bins_csv(model_calib_csv, model_calib_rows) + + model_meta = { + "model_name": spec.name, + "checkpoint_path": spec.checkpoint_path, + "run_dir": run_dir, + "split": args.split, + "offset_years": args.offset_years, + "eval_horizons": [float(x) for x in args.eval_horizons], + "top_k_causes": int(args.top_k_causes), + "top_cause_ids": top_cause_ids.tolist(), + "top_causes": top_causes_meta, + "integrity": {spec.name: integrity_meta.get(spec.name, {})}, + "paths": { + "results_csv": model_out_csv, + "calibration_bins_csv": model_calib_csv, + }, + } + with open(model_meta_json, "w") as f: + json.dump(model_meta, f, indent=2) + + print(f"Wrote per-model results to {model_out_csv}") + + write_results_csv(args.out_csv, rows) + + # Write calibration curve points to a separate CSV. + out_dir = os.path.dirname(os.path.abspath(args.out_csv)) or "." 
+ calib_csv_path = os.path.join(out_dir, "calibration_bins.csv") + write_calibration_bins_csv(calib_csv_path, calib_rows) + + meta = { + "split": args.split, + "offset_years": args.offset_years, + "eval_horizons": [float(x) for x in args.eval_horizons], + "top_k_causes": int(args.top_k_causes), + "top_cause_ids": top_cause_ids.tolist(), + "top_causes": top_causes_meta, + "integrity": integrity_meta, + "notes": { + "task_a_label": "Delphi2M-compatible: disease occurs ANYTIME after context (ever in remaining sequence)", + "task_a_legacy_label": "Secondary: disease occurs within tau_max after context", + "task_b_label": "all-cause event within horizon (equivalent to next disease event within horizon)", + "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.", + }, + } + with open(args.out_meta_json, "w") as f: + json.dump(meta, f, indent=2) + + print(f"Wrote {args.out_csv} with {len(rows)} rows") + print(f"Wrote {calib_csv_path} with {len(calib_rows)} rows") + print(f"Wrote {args.out_meta_json}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/models_eval_example.json b/models_eval_example.json new file mode 100644 index 0000000..462914e --- /dev/null +++ b/models_eval_example.json @@ -0,0 +1,30 @@ +[ + { + "name": "delphi_fork_discrete_time_cif_fullcov", + "model_type": "delphi_fork", + "loss_type": "discrete_time_cif", + "full_cov": true, + "checkpoint_path": "runs/delphi_fork_discrete_time_cif_sinusoidal_fullcov_20260109-222502/best_model.pt" + }, + { + "name": "delphi_fork_exponential_fullcov", + "model_type": "delphi_fork", + "loss_type": "exponential", + "full_cov": true, + "checkpoint_path": "runs/delphi_fork_exponential_sinusoidal_fullcov_20260109-222502/best_model.pt" + }, + { + "name": "sap_delphi_discrete_time_cif_fullcov", + "model_type": "sap_delphi", + "loss_type": "discrete_time_cif", + "full_cov": true, + "checkpoint_path": 
"runs/sap_delphi_discrete_time_cif_sinusoidal_fullcov_20260109-222502/best_model.pt" + }, + { + "name": "sap_delphi_exponential_fullcov", + "model_type": "sap_delphi", + "loss_type": "exponential", + "full_cov": true, + "checkpoint_path": "runs/sap_delphi_exponential_sinusoidal_fullcov_20260109-222502/best_model.pt" + } +] \ No newline at end of file