# DeepHealth/evaluation_age_time_dependent.py

from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
import torch
try:
from tqdm import tqdm
except Exception: # pragma: no cover
def tqdm(x, **kwargs):
return x
from utils import (
multi_hot_ever_within_horizon,
multi_hot_selected_causes_within_horizon,
sample_context_in_fixed_age_bin,
)
from torch_metrics import compute_binary_metrics_torch
def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray:
with np.errstate(invalid="ignore"):
return np.nanmean(x, axis=axis)
def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray:
"""NaN-aware sample std (ddof=1), matching pandas std() semantics."""
x = np.asarray(x, dtype=float)
mask = np.isfinite(x)
cnt = mask.sum(axis=axis)
# mean over finite entries
x0 = np.where(mask, x, 0.0)
mean = x0.sum(axis=axis) / np.maximum(cnt, 1)
# sum of squared deviations over finite entries
dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0)
ss = dev2.sum(axis=axis)
denom = cnt - 1
out = np.sqrt(ss / np.maximum(denom, 1))
out = np.where(denom > 0, out, np.nan)
return out
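# Example (illustrative only, not executed at import): `_nanstd_np_ddof1`
# matches pandas' NaN-aware sample std; columns with fewer than two finite
# values yield NaN.
#
#   x = np.array([[1.0, np.nan],
#                 [3.0, 4.0],
#                 [5.0, np.nan]])
#   _nanstd_np_ddof1(x, axis=0)              # -> array([2., nan])
#   pd.DataFrame(x).std(axis=0).to_numpy()   # -> array([2., nan])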
def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray:
"""NaN-aware weighted mean.
Only bins with finite x contribute to both numerator and denominator.
If denom==0 -> NaN.
"""
x = np.asarray(x, dtype=float)
w = np.asarray(w, dtype=float)
if axis != 0:
raise ValueError("_weighted_mean_np currently supports axis=0 only")
# Broadcast weights along trailing dims of x.
while w.ndim < x.ndim:
w = w[..., None]
w = np.broadcast_to(w, x.shape)
mask = np.isfinite(x)
num = np.where(mask, x * w, 0.0).sum(axis=0)
denom = np.where(mask, w, 0.0).sum(axis=0)
return np.where(denom > 0.0, num / denom, np.nan)
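# Example (illustrative only, not executed at import): bins whose metric is
# NaN are dropped from both numerator and denominator.
#
#   vals = np.array([0.8, np.nan, 0.6])   # per-bin metric values
#   wts = np.array([10.0, 5.0, 30.0])     # per-bin n_samples
#   _weighted_mean_np(vals, wts)          # -> (0.8*10 + 0.6*30) / (10 + 30) = 0.65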
def _blocks_to_df_by_bin(
blocks: List[Dict[str, Any]],
*,
topk_percents: np.ndarray,
) -> pd.DataFrame:
"""Convert per-block column vectors into the long-format per-bin DataFrame.
This does a single vectorized reshape per block (cause-major ordering), and
concatenates columns once at the end.
"""
if len(blocks) == 0:
return pd.DataFrame(
columns=[
"mc_idx",
"age_bin_id",
"age_bin_low",
"age_bin_high",
"horizon_tau",
"topk_percent",
"cause_id",
"n_samples",
"n_positives",
"auc",
"auprc",
"recall_at_K",
"precision_at_K",
"brier_score",
]
)
P = int(topk_percents.size)
cols: Dict[str, List[np.ndarray]] = {
"mc_idx": [],
"age_bin_id": [],
"age_bin_low": [],
"age_bin_high": [],
"horizon_tau": [],
"topk_percent": [],
"cause_id": [],
"n_samples": [],
"n_positives": [],
"auc": [],
"auprc": [],
"recall_at_K": [],
"precision_at_K": [],
"brier_score": [],
}
for blk in blocks:
cause_id = np.asarray(blk["cause_id"], dtype=int)
K = int(cause_id.size)
n_rows = K * P
cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int))
cols["age_bin_id"].append(
np.full(n_rows, int(blk["age_bin_id"]), dtype=int))
cols["age_bin_low"].append(
np.full(n_rows, float(blk["age_bin_low"]), dtype=float))
cols["age_bin_high"].append(
np.full(n_rows, float(blk["age_bin_high"]), dtype=float))
cols["horizon_tau"].append(
np.full(n_rows, float(blk["horizon_tau"]), dtype=float))
cols["cause_id"].append(np.repeat(cause_id, P))
cols["topk_percent"].append(np.tile(topk_percents.astype(float), K))
cols["n_samples"].append(
np.full(n_rows, int(blk["n_samples"]), dtype=int))
n_pos = np.asarray(blk["n_positives"], dtype=int)
cols["n_positives"].append(np.repeat(n_pos, P))
auc = np.asarray(blk["auc"], dtype=float)
auprc = np.asarray(blk["auprc"], dtype=float)
brier = np.asarray(blk["brier_score"], dtype=float)
cols["auc"].append(np.repeat(auc, P))
cols["auprc"].append(np.repeat(auprc, P))
cols["brier_score"].append(np.repeat(brier, P))
# precision/recall are stored as (P,K); we want cause-major rows, i.e.
# (K,P) then flatten.
prec = np.asarray(blk["precision_at_K"], dtype=float)
rec = np.asarray(blk["recall_at_K"], dtype=float)
if prec.shape != (P, K) or rec.shape != (P, K):
raise ValueError(
f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}"
)
cols["precision_at_K"].append(prec.T.reshape(-1))
cols["recall_at_K"].append(rec.T.reshape(-1))
out = {k: np.concatenate(v, axis=0) for k, v in cols.items()}
return pd.DataFrame(out)
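# Example (illustrative sketch of the per-block schema consumed above; the
# numbers are made up):
#
#   topk = np.array([1.0, 5.0])                       # P = 2
#   blk = dict(
#       mc_idx=0, age_bin_id=0, age_bin_low=40.0, age_bin_high=45.0,
#       horizon_tau=5.0, n_samples=100,
#       cause_id=np.array([0, 1, 2]),                  # K = 3
#       n_positives=np.array([7, 3, 1]),
#       auc=np.array([0.71, 0.64, np.nan]),
#       auprc=np.array([0.20, 0.11, np.nan]),
#       brier_score=np.array([0.060, 0.031, 0.012]),
#       precision_at_K=np.full((2, 3), 0.10),          # shape (P, K)
#       recall_at_K=np.full((2, 3), 0.50),
#   )
#   df = _blocks_to_df_by_bin([blk], topk_percents=topk)  # K*P = 6 rows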
def aggregate_metrics_columnar(
blocks: List[Dict[str, Any]],
*,
topk_percents: np.ndarray,
cause_id: np.ndarray,
) -> pd.DataFrame:
"""Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std).
This is a vectorized, columnar replacement for the old pandas groupby/apply.
Semantics match the previous implementation:
- bins with n_samples==0 are excluded from bin-aggregation
- macro: unweighted mean over bins (NaN-aware)
- weighted: weighted mean over bins using weights=n_samples (NaN-aware)
- across MC: mean/std (ddof=1), NaN-aware
"""
if len(blocks) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
P = int(topk_percents.size)
cause_id = np.asarray(cause_id, dtype=int)
K = int(cause_id.size)
# Group blocks by (mc_idx, horizon_tau)
keys: List[Tuple[int, float]] = []
grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {}
for blk in blocks:
key = (int(blk["mc_idx"]), float(blk["horizon_tau"]))
if key not in grouped:
grouped[key] = []
keys.append(key)
grouped[key].append(blk)
mc_vals = sorted({k[0] for k in keys})
tau_vals = sorted({k[1] for k in keys})
M = len(mc_vals)
T = len(tau_vals)
mc_index = {mc: i for i, mc in enumerate(mc_vals)}
tau_index = {tau: i for i, tau in enumerate(tau_vals)}
    # Per-(mc, tau) aggregates, stored separately for the "macro" and
    # "weighted" aggregation types: per-cause metrics are (M, T, K);
    # per-(top-k, cause) metrics are (M, T, P, K).
auc_macro = np.full((M, T, K), np.nan, dtype=float)
auc_weighted = np.full((M, T, K), np.nan, dtype=float)
ap_macro = np.full((M, T, K), np.nan, dtype=float)
ap_weighted = np.full((M, T, K), np.nan, dtype=float)
brier_macro = np.full((M, T, K), np.nan, dtype=float)
brier_weighted = np.full((M, T, K), np.nan, dtype=float)
prec_macro = np.full((M, T, P, K), np.nan, dtype=float)
prec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
rec_macro = np.full((M, T, P, K), np.nan, dtype=float)
rec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
n_bins_used = np.zeros((M, T), dtype=float)
n_samples_total = np.zeros((M, T), dtype=float)
n_pos_total = np.zeros((M, T, K), dtype=float)
for (mc, tau), blks in grouped.items():
mi = mc_index[mc]
ti = tau_index[tau]
# keep only bins with n_samples>0
blks_nz = [b for b in blks if int(b["n_samples"]) > 0]
if len(blks_nz) == 0:
n_bins_used[mi, ti] = 0.0
n_samples_total[mi, ti] = 0.0
n_pos_total[mi, ti, :] = 0.0
continue
w = np.asarray([int(b["n_samples"])
for b in blks_nz], dtype=float) # (B,)
n_bins_used[mi, ti] = float(len(w))
n_samples_total[mi, ti] = float(w.sum())
npos = np.stack([np.asarray(b["n_positives"], dtype=float)
for b in blks_nz], axis=0) # (B,K)
n_pos_total[mi, ti, :] = npos.sum(axis=0)
auc_b = np.stack([np.asarray(b["auc"], dtype=float)
for b in blks_nz], axis=0) # (B,K)
ap_b = np.stack([np.asarray(b["auprc"], dtype=float)
for b in blks_nz], axis=0)
brier_b = np.stack([np.asarray(b["brier_score"], dtype=float)
for b in blks_nz], axis=0)
auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0)
ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0)
brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0)
auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0)
ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0)
brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0)
prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float)
for b in blks_nz], axis=0) # (B,P,K)
rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float)
for b in blks_nz], axis=0)
# macro mean over bins
prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0)
rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0)
# weighted mean over bins (weights along bin axis)
w3 = w.reshape(-1, 1, 1)
prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0)
rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0)
# Across-MC aggregation (mean/std), then emit long-format df keyed by
# (agg_type, horizon_tau, topk_percent, cause_id)
rows: Dict[str, List[np.ndarray]] = {
"agg_type": [],
"horizon_tau": [],
"topk_percent": [],
"cause_id": [],
"n_mc": [],
"n_bins_used_mean": [],
"n_samples_total_mean": [],
"n_positives_total_mean": [],
"auc_mean": [],
"auc_std": [],
"auprc_mean": [],
"auprc_std": [],
"recall_at_K_mean": [],
"recall_at_K_std": [],
"precision_at_K_mean": [],
"precision_at_K_std": [],
"brier_score_mean": [],
"brier_score_std": [],
}
cause_long = np.repeat(cause_id, P)
topk_long = np.tile(topk_percents.astype(float), K)
n_mc_val = float(M)
for ti, tau in enumerate(tau_vals):
# scalar totals (repeat across causes/topk)
n_bins_mean = float(
np.mean(n_bins_used[:, ti])) if M > 0 else float("nan")
n_samp_mean = float(
np.mean(n_samples_total[:, ti])) if M > 0 else float("nan")
n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0) # (K,)
for agg_type in ("macro", "weighted"):
if agg_type == "macro":
auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0)
auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0)
ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0)
ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0)
brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0)
brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0)
prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0) # (P,K)
prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0)
rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0)
rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0)
else:
auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0)
auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0)
ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0)
ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0)
brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0)
brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0)
prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0)
prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0)
rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0)
rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0)
n_rows = K * P
rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object))
rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float))
rows["topk_percent"].append(topk_long)
rows["cause_id"].append(cause_long)
rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float))
rows["n_bins_used_mean"].append(
np.full(n_rows, n_bins_mean, dtype=float))
rows["n_samples_total_mean"].append(
np.full(n_rows, n_samp_mean, dtype=float))
rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P))
rows["auc_mean"].append(np.repeat(auc_m, P))
rows["auc_std"].append(np.repeat(auc_s, P))
rows["auprc_mean"].append(np.repeat(ap_m, P))
rows["auprc_std"].append(np.repeat(ap_s, P))
rows["brier_score_mean"].append(np.repeat(brier_m, P))
rows["brier_score_std"].append(np.repeat(brier_s, P))
rows["precision_at_K_mean"].append(prec_m.T.reshape(-1))
rows["precision_at_K_std"].append(prec_s.T.reshape(-1))
rows["recall_at_K_mean"].append(rec_m.T.reshape(-1))
rows["recall_at_K_std"].append(rec_s.T.reshape(-1))
out = {k: np.concatenate(v, axis=0) for k, v in rows.items()}
df = pd.DataFrame(out)
return df.sort_values(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True
)
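# Example (illustrative, continuing the `blk`/`topk` sketch above): with a
# single MC repetition the *_std columns are NaN, since ddof=1 needs at least
# two values.
#
#   df_agg = aggregate_metrics_columnar(
#       [blk], topk_percents=topk, cause_id=np.array([0, 1, 2])
#   )
#   # -> one "macro" and one "weighted" row per (horizon_tau, topk_percent, cause_id)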
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
"""Aggregate per-bin age evaluation results.
Produces both:
- macro: unweighted mean over bins with n_samples>0
- weighted: weighted mean over bins using weights=n_samples
Then aggregates across MC repetitions (mean/std).
Requires df_by_bin to include:
mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score
Returns:
DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id)
"""
if df_by_bin is None or len(df_by_bin) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
g = group[group["n_samples"] > 0]
if len(g) == 0:
return pd.Series(
dict(
n_bins_used=0,
n_samples_total=0,
n_positives_total=0,
auc=float("nan"),
auprc=float("nan"),
recall_at_K=float("nan"),
precision_at_K=float("nan"),
brier_score=float("nan"),
)
)
n_bins_used = int(g["age_bin_id"].nunique())
n_samples_total = int(g["n_samples"].sum())
n_positives_total = int(g["n_positives"].sum())
if not weighted:
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=float(g["auc"].mean()),
auprc=float(g["auprc"].mean()),
recall_at_K=float(g["recall_at_K"].mean()),
precision_at_K=float(g["precision_at_K"].mean()),
brier_score=float(g["brier_score"].mean()),
)
)
w = g["n_samples"].to_numpy(dtype=float)
w_sum = float(w.sum())
if w_sum <= 0.0:
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=float("nan"),
auprc=float("nan"),
recall_at_K=float("nan"),
precision_at_K=float("nan"),
brier_score=float("nan"),
)
)
def _wavg(col: str) -> float:
return float(np.average(g[col].to_numpy(dtype=float), weights=w))
return pd.Series(
dict(
n_bins_used=n_bins_used,
n_samples_total=n_samples_total,
n_positives_total=n_positives_total,
auc=_wavg("auc"),
auprc=_wavg("auprc"),
recall_at_K=_wavg("recall_at_K"),
precision_at_K=_wavg("precision_at_K"),
brier_score=_wavg("brier_score"),
)
)
    # NOTE: this function is kept for backward compatibility (e.g., when
    # callers load a per-bin CSV and need to re-aggregate); prefer
    # `aggregate_metrics_columnar` during evaluation. The `_bin_aggregate`
    # helper above documents the per-group semantics but is not called below.
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
df = df_by_bin[df_by_bin["n_samples"] > 0].copy()
if len(df) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
# Macro: mean over bins
df_mc_macro = (
df.groupby(group_keys, as_index=False)
.agg(
n_bins_used=("age_bin_id", "nunique"),
n_samples_total=("n_samples", "sum"),
n_positives_total=("n_positives", "sum"),
auc=("auc", "mean"),
auprc=("auprc", "mean"),
recall_at_K=("recall_at_K", "mean"),
precision_at_K=("precision_at_K", "mean"),
brier_score=("brier_score", "mean"),
)
)
df_mc_macro["agg_type"] = "macro"
# Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric
w = df["n_samples"].astype(float)
df_w = df.copy()
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
m = df_w[col].astype(float)
ww = w.where(m.notna(), other=0.0)
df_w[f"__num_{col}"] = (m.fillna(0.0) * w)
df_w[f"__den_{col}"] = ww
df_mc_w = df_w.groupby(group_keys, as_index=False).agg(
n_bins_used=("age_bin_id", "nunique"),
n_samples_total=("n_samples", "sum"),
n_positives_total=("n_positives", "sum"),
**{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
**{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
)
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
num = df_mc_w[f"__num_{col}"].astype(float)
den = df_mc_w[f"__den_{col}"].astype(float)
df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan"))
df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True)
df_mc_w["agg_type"] = "weighted"
df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True)
df_agg = (
df_mc_binagg.groupby(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False)
.agg(
n_mc=("mc_idx", "nunique"),
n_bins_used_mean=("n_bins_used", "mean"),
n_samples_total_mean=("n_samples_total", "mean"),
n_positives_total_mean=("n_positives_total", "mean"),
auc_mean=("auc", "mean"),
auc_std=("auc", "std"),
auprc_mean=("auprc", "mean"),
auprc_std=("auprc", "std"),
recall_at_K_mean=("recall_at_K", "mean"),
recall_at_K_std=("recall_at_K", "std"),
precision_at_K_mean=("precision_at_K", "mean"),
precision_at_K_std=("precision_at_K", "std"),
brier_score_mean=("brier_score", "mean"),
brier_score_std=("brier_score", "std"),
)
.sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True)
)
return df_agg
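# Example (illustrative): re-aggregating a previously saved per-bin CSV. The
# path is hypothetical; the file is assumed to contain the df_by_bin columns
# listed in the docstring above.
#
#   df_by_bin = pd.read_csv("eval_age_by_bin.csv")
#   df_agg = aggregate_age_bin_results(df_by_bin)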
# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`.
# NumPy/Pandas are only used for final CSV formatting/aggregation.
@dataclass
class EvalAgeConfig:
horizons_years: Sequence[float]
age_bins: Sequence[Tuple[float, float]]
topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
n_mc: int = 5
seed: int = 0
cause_ids: Optional[Sequence[int]] = None
store_per_cause: bool = True
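# Example (illustrative): 5- and 10-year horizons with 5-year age bins from
# 40 to 70, evaluated over 3 MC repetitions.
#
#   cfg = EvalAgeConfig(
#       horizons_years=[5.0, 10.0],
#       age_bins=[(float(a), float(a) + 5.0) for a in range(40, 70, 5)],
#       topk_percents=(1.0, 5.0, 10.0),
#       n_mc=3,
#       seed=0,
#   )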
@torch.inference_mode()
def evaluate_time_dependent_age_bins(
model: torch.nn.Module,
head: torch.nn.Module,
criterion,
dataloader: torch.utils.data.DataLoader,
n_disease: int,
cfg: EvalAgeConfig,
device: str | torch.device,
mc_offset: int = 0,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""Delphi-2M-style age-bin evaluation with strict horizon alignment.
Semantics (strict): for each (MC, horizon tau, age bin) we independently:
- build the eligible token set within that bin
- enforce follow-up coverage: t_ctx + tau <= t_end
- randomly sample exactly one token per individual within the bin (de-dup)
- recompute context representations and predictions for that (tau, bin)
Returns:
df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id)
df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted}
"""
device = torch.device(device)
model.eval()
head.eval()
horizons_years = [float(x) for x in cfg.horizons_years]
if len(horizons_years) == 0:
raise ValueError("cfg.horizons_years must be non-empty")
age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
if len(age_bins) == 0:
raise ValueError("cfg.age_bins must be non-empty")
for (a, b) in age_bins:
if not (b > a):
raise ValueError(
f"age_bins must be (low, high) with high>low; got {(a, b)}")
topk_percents = [float(x) for x in cfg.topk_percents]
if len(topk_percents) == 0:
raise ValueError("cfg.topk_percents must be non-empty")
if any((p <= 0.0 or p > 100.0) for p in topk_percents):
raise ValueError(
f"All topk_percents must be in (0,100]; got {topk_percents}")
if int(cfg.n_mc) <= 0:
raise ValueError("cfg.n_mc must be >= 1")
if cfg.cause_ids is None:
cause_ids = None
n_causes_eval = int(n_disease)
cause_id_vec = np.arange(n_causes_eval, dtype=int)
else:
cause_ids = torch.tensor(
list(cfg.cause_ids), dtype=torch.long, device=device)
n_causes_eval = int(cause_ids.numel())
cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int)
topk_percents_np = np.asarray(topk_percents, dtype=float)
# Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends.
blocks: List[Dict[str, Any]] = []
for mc_idx in range(int(cfg.n_mc)):
global_mc_idx = int(mc_offset) + int(mc_idx)
# Storage for this MC only: (tau, bin) -> list of GPU tensors.
# This keeps computations GPU-first while preventing a factor-n_mc
# blow-up in GPU memory.
y_true_mc: List[List[List[torch.Tensor]]] = [
[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
]
y_pred_mc: List[List[List[torch.Tensor]]] = [
[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
]
# tqdm over batches; include MC idx in description.
for batch_idx, batch in enumerate(
tqdm(dataloader,
desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
):
event_seq, time_seq, cont_feats, cate_feats, sexes = batch
event_seq = event_seq.to(device)
time_seq = time_seq.to(device)
cont_feats = cont_feats.to(device)
cate_feats = cate_feats.to(device)
sexes = sexes.to(device)
B = int(event_seq.size(0))
b = torch.arange(B, device=device)
            # Hoist the backbone forward pass: its inputs are identical across
            # (tau, age_bin) within this batch, so computing it once here is
            # numerically identical to recomputing it inside the loops below.
h = model(event_seq, time_seq, sexes,
cont_feats, cate_feats) # (B,L,D)
for tau_idx, tau_y in enumerate(horizons_years):
tau_tensor = torch.tensor(float(tau_y), device=device)
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
# Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
seed = (
int(cfg.seed)
+ (100_000 * int(global_mc_idx))
+ (1_000 * int(tau_idx))
+ (10 * int(bin_idx))
+ int(batch_idx)
)
keep, t_ctx = sample_context_in_fixed_age_bin(
event_seq=event_seq,
time_seq=time_seq,
tau_years=float(tau_y),
age_bin=(float(a_lo), float(a_hi)),
seed=seed,
)
if not keep.any():
continue
# Bin-specific prediction: context indices differ per (tau, bin)
# but the backbone features do not.
c = h[b, t_ctx]
logits = head(c)
cifs = criterion.calculate_cifs(
logits, taus=tau_tensor
)
if cifs.ndim != 2:
raise ValueError(
"criterion.calculate_cifs must return (B,K) for scalar tau; "
f"got shape={tuple(cifs.shape)}"
)
if cause_ids is None:
y = multi_hot_ever_within_horizon(
event_seq=event_seq,
time_seq=time_seq,
t_ctx=t_ctx,
tau_years=float(tau_y),
n_disease=n_disease,
)
preds = cifs
else:
y = multi_hot_selected_causes_within_horizon(
event_seq=event_seq,
time_seq=time_seq,
t_ctx=t_ctx,
tau_years=float(tau_y),
cause_ids=cause_ids,
n_disease=n_disease,
)
preds = cifs.index_select(dim=1, index=cause_ids)
y_true_mc[tau_idx][bin_idx].append(
y[keep].detach().to(dtype=torch.bool)
)
y_pred_mc[tau_idx][bin_idx].append(
preds[keep].detach().to(dtype=torch.float32)
)
# Aggregate this MC immediately (frees GPU memory before next MC).
for h_idx, tau_y in enumerate(horizons_years):
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
if len(y_true_mc[h_idx][bin_idx]) == 0:
# No samples in this bin for this (mc, tau): store a single
# block with NaN metric vectors.
K = int(n_causes_eval)
P = int(topk_percents_np.size)
blocks.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
n_samples=0,
cause_id=cause_id_vec,
n_positives=np.zeros((K,), dtype=int),
auc=np.full((K,), np.nan, dtype=float),
auprc=np.full((K,), np.nan, dtype=float),
brier_score=np.full((K,), np.nan, dtype=float),
precision_at_K=np.full((P, K), np.nan, dtype=float),
recall_at_K=np.full((P, K), np.nan, dtype=float),
)
)
continue
yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0)
if tuple(yb_t.shape) != tuple(pb_t.shape):
raise ValueError(
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
)
n_samples = int(yb_t.size(0))
metrics = compute_binary_metrics_torch(
y_true=yb_t,
y_pred=pb_t,
k_percents=topk_percents,
tie_mode="exact",
chunk_size=128,
compute_ici=False,
)
# Collect a single columnar block (vectors, not per-row dicts).
blocks.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
n_samples=int(n_samples),
cause_id=cause_id_vec,
n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int),
auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float),
auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float),
brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float),
precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float),
recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float),
)
)
# Aggregation is computed from columnar blocks (fast, no pandas apply).
df_agg = aggregate_metrics_columnar(
blocks,
topk_percents=topk_percents_np,
cause_id=cause_id_vec,
)
if bool(cfg.store_per_cause):
df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np)
else:
df_by_bin = pd.DataFrame(
columns=[
"mc_idx",
"age_bin_id",
"age_bin_low",
"age_bin_high",
"horizon_tau",
"topk_percent",
"cause_id",
"n_samples",
"n_positives",
"auc",
"auprc",
"recall_at_K",
"precision_at_K",
"brier_score",
]
)
return df_by_bin, df_agg
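# Example (illustrative sketch of a call site): `model`, `head`, `criterion`
# and `val_loader` are assumed to exist and to follow the interfaces used
# above (the backbone returns (B, L, D) features; the criterion exposes
# `calculate_cifs`). Output paths are hypothetical.
#
#   df_by_bin, df_agg = evaluate_time_dependent_age_bins(
#       model=model, head=head, criterion=criterion,
#       dataloader=val_loader, n_disease=n_disease,
#       cfg=cfg, device="cuda",
#   )
#   df_by_bin.to_csv("eval_age_by_bin.csv", index=False)
#   df_agg.to_csv("eval_age_agg.csv", index=False)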