Add function to drop zero-positive rows and update CSV export logic in age-bin evaluation

This commit is contained in:
2026-01-16 17:51:00 +08:00
parent 4068310a12
commit a637beb220
2 changed files with 521 additions and 83 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
@@ -23,6 +23,364 @@ from utils import (
from torch_metrics import compute_binary_metrics_torch
def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray:
    """Return the NaN-ignoring mean of ``x`` along ``axis`` without warning noise.

    Slices that contain only NaN yield NaN. ``np.nanmean`` reports that case
    through the ``warnings`` module ("Mean of empty slice"), which
    ``np.errstate`` does not control, so the warning is filtered explicitly
    here in addition to silencing floating-point "invalid" errors.

    Args:
        x: Input array (anything ``np.nanmean`` accepts).
        axis: Axis to reduce over.

    Returns:
        The NaN-aware mean along ``axis``; NaN where a slice has no non-NaN
        values.
    """
    # Function-scope import keeps this helper self-contained.
    import warnings

    with warnings.catch_warnings():
        # Filter only the expected all-NaN message so that unrelated
        # RuntimeWarnings still reach the caller.
        warnings.filterwarnings(
            "ignore",
            message="Mean of empty slice",
            category=RuntimeWarning,
        )
        with np.errstate(invalid="ignore"):
            return np.nanmean(x, axis=axis)
def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray:
    """Sample standard deviation (ddof=1) along ``axis``, skipping non-finite entries.

    Entries that fail ``np.isfinite`` (NaN and +/-inf) are excluded from both
    the mean and the squared-deviation sum.

    Args:
        x: Input values; converted to a float array.
        axis: Axis to reduce over.

    Returns:
        Standard deviation along ``axis``; NaN wherever fewer than two finite
        entries are available.
    """
    values = np.asarray(x, dtype=float)
    finite = np.isfinite(values)
    n_finite = finite.sum(axis=axis)

    # Mean of the finite entries; the max() guard avoids dividing by zero when
    # a slice has no finite values (that slice is set to NaN at the end).
    center = np.where(finite, values, 0.0).sum(axis=axis) / np.maximum(n_finite, 1)

    # Squared deviations from the per-slice mean, zeroed for excluded entries.
    residual_sq = (values - np.expand_dims(center, axis=axis)) ** 2
    sq_sum = np.where(finite, residual_sq, 0.0).sum(axis=axis)

    dof = n_finite - 1
    std = np.sqrt(sq_sum / np.maximum(dof, 1))
    # With zero degrees of freedom (0 or 1 finite entries) the result is undefined.
    return np.where(dof > 0, std, np.nan)
def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray:
    """NaN-aware weighted mean over the leading axis.

    Only entries where ``x`` is finite contribute to both the numerator and
    the denominator. Where the accumulated weight is not positive (for example
    a column with no finite values), the result is NaN.

    Args:
        x: Values to average; converted to a float array.
        w: Weights along axis 0. Trailing singleton dimensions are appended
            until ``w`` has as many dimensions as ``x``, then it is broadcast
            to ``x.shape``.
        axis: Must be 0; other values are rejected.

    Returns:
        The weighted mean with axis 0 reduced away.

    Raises:
        ValueError: If ``axis`` is not 0, or (from ``np.broadcast_to``) if
            ``w`` cannot be broadcast to the shape of ``x``.
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    if axis != 0:
        raise ValueError("_weighted_mean_np currently supports axis=0 only")

    # Broadcast weights along trailing dims of x.
    while w.ndim < x.ndim:
        w = w[..., None]
    w = np.broadcast_to(w, x.shape)

    mask = np.isfinite(x)
    # A zero denominator is an expected, documented outcome (-> NaN), so the
    # floating-point warnings from 0/0 and from masked-out products are
    # silenced here. The numeric results are unchanged.
    with np.errstate(divide="ignore", invalid="ignore"):
        num = np.where(mask, x * w, 0.0).sum(axis=0)
        denom = np.where(mask, w, 0.0).sum(axis=0)
        ratio = num / denom
    return np.where(denom > 0.0, ratio, np.nan)
def _blocks_to_df_by_bin(
    blocks: List[Dict[str, Any]],
    *,
    topk_percents: np.ndarray,
) -> pd.DataFrame:
    """Expand columnar result blocks into a long-format per-bin DataFrame.

    Each block contributes ``K * P`` rows, where ``K`` is the number of entries
    in the block's ``cause_id`` vector and ``P`` is ``topk_percents.size``.
    Rows are cause-major: all top-k percents for the first cause, then all for
    the second, and so on.

    Args:
        blocks: Per-block dicts. Read keys: scalar ``mc_idx``, ``age_bin_id``,
            ``age_bin_low``, ``age_bin_high``, ``horizon_tau``, ``n_samples``;
            per-cause vectors ``cause_id``, ``n_positives``, ``auc``,
            ``auprc``, ``brier_score``; and ``precision_at_K`` /
            ``recall_at_K`` matrices of shape ``(P, K)``.
        topk_percents: Top-k percent grid shared by all blocks.

    Returns:
        A DataFrame with one row per (block, cause, top-k percent). For empty
        input, an empty DataFrame carrying the same column names.

    Raises:
        ValueError: If a block's precision/recall matrix is not ``(P, K)``.
    """
    column_order = (
        "mc_idx",
        "age_bin_id",
        "age_bin_low",
        "age_bin_high",
        "horizon_tau",
        "topk_percent",
        "cause_id",
        "n_samples",
        "n_positives",
        "auc",
        "auprc",
        "recall_at_K",
        "precision_at_K",
        "brier_score",
    )
    if not blocks:
        return pd.DataFrame(columns=list(column_order))

    # Block-level scalars that are repeated on every row of the block.
    scalar_fields = (
        ("mc_idx", int),
        ("age_bin_id", int),
        ("age_bin_low", float),
        ("age_bin_high", float),
        ("horizon_tau", float),
        ("n_samples", int),
    )
    # Per-cause vectors that are repeated once per top-k percent.
    per_cause_fields = (
        ("n_positives", int),
        ("auc", float),
        ("auprc", float),
        ("brier_score", float),
    )

    n_topk = int(topk_percents.size)
    topk_as_float = topk_percents.astype(float)
    chunks: Dict[str, List[np.ndarray]] = {name: [] for name in column_order}

    for blk in blocks:
        causes = np.asarray(blk["cause_id"], dtype=int)
        n_causes = int(causes.size)
        n_rows = n_causes * n_topk

        for key, caster in scalar_fields:
            chunks[key].append(np.full(n_rows, caster(blk[key]), dtype=caster))

        chunks["cause_id"].append(np.repeat(causes, n_topk))
        chunks["topk_percent"].append(np.tile(topk_as_float, n_causes))

        for key, dtype in per_cause_fields:
            per_cause = np.asarray(blk[key], dtype=dtype)
            chunks[key].append(np.repeat(per_cause, n_topk))

        # Precision/recall arrive as (P, K); transposing before flattening
        # yields the cause-major row order used by the other columns.
        precision = np.asarray(blk["precision_at_K"], dtype=float)
        recall = np.asarray(blk["recall_at_K"], dtype=float)
        expected_shape = (n_topk, n_causes)
        if not (precision.shape == expected_shape and recall.shape == expected_shape):
            raise ValueError(
                f"Expected precision/recall shapes (P,K)=({n_topk},{n_causes}); got {precision.shape} and {recall.shape}"
            )
        chunks["precision_at_K"].append(precision.T.reshape(-1))
        chunks["recall_at_K"].append(recall.T.reshape(-1))

    return pd.DataFrame(
        {name: np.concatenate(parts, axis=0) for name, parts in chunks.items()}
    )
def aggregate_metrics_columnar(
    blocks: List[Dict[str, Any]],
    *,
    topk_percents: np.ndarray,
    cause_id: np.ndarray,
) -> pd.DataFrame:
    """Aggregate columnar per-bin blocks over age bins, then over MC repetitions.

    Blocks are grouped by ``(mc_idx, horizon_tau)``. Within each group:
      - blocks with ``n_samples == 0`` are dropped before bin aggregation;
      - "macro" is the unweighted NaN-aware mean over the remaining blocks;
      - "weighted" is the NaN-aware mean weighted by each block's ``n_samples``.
    The per-group values are then summarised across MC indices with a
    NaN-aware mean and a sample standard deviation (ddof=1).

    Args:
        blocks: Per-block dicts with scalar ``mc_idx``, ``horizon_tau`` and
            ``n_samples``; per-cause vectors ``n_positives``, ``auc``,
            ``auprc`` and ``brier_score``; and ``precision_at_K`` /
            ``recall_at_K`` matrices that are stored into ``(P, K)`` slots.
        topk_percents: Top-k percent grid (``P`` entries).
        cause_id: Cause identifiers (``K`` entries) used to label output rows.

    Returns:
        Long-format DataFrame with ``K * P`` rows per (agg_type, horizon_tau),
        sorted by agg_type, horizon_tau, topk_percent and cause_id. For empty
        input, an empty DataFrame carrying the same column names.
    """
    out_columns = [
        "agg_type",
        "horizon_tau",
        "topk_percent",
        "cause_id",
        "n_mc",
        "n_bins_used_mean",
        "n_samples_total_mean",
        "n_positives_total_mean",
        "auc_mean",
        "auc_std",
        "auprc_mean",
        "auprc_std",
        "recall_at_K_mean",
        "recall_at_K_std",
        "precision_at_K_mean",
        "precision_at_K_std",
        "brier_score_mean",
        "brier_score_std",
    ]
    if not blocks:
        return pd.DataFrame(columns=out_columns)

    vector_metrics = ("auc", "auprc", "brier_score")        # one value per cause
    matrix_metrics = ("recall_at_K", "precision_at_K")      # (P, K) per block
    agg_kinds = ("macro", "weighted")

    n_topk = int(topk_percents.size)
    cause_id = np.asarray(cause_id, dtype=int)
    n_causes = int(cause_id.size)

    # Bucket blocks by (mc_idx, horizon_tau).
    by_run: Dict[Tuple[int, float], List[Dict[str, Any]]] = {}
    for blk in blocks:
        run_key = (int(blk["mc_idx"]), float(blk["horizon_tau"]))
        by_run.setdefault(run_key, []).append(blk)

    mc_sorted = sorted({mc for mc, _ in by_run})
    tau_sorted = sorted({tau for _, tau in by_run})
    n_mc_runs = len(mc_sorted)
    n_tau = len(tau_sorted)
    mc_pos = {mc: i for i, mc in enumerate(mc_sorted)}
    tau_pos = {tau: i for i, tau in enumerate(tau_sorted)}

    # Bin-aggregated metrics per aggregation kind, indexed (mc, tau, ...).
    # Cells that never receive data stay NaN.
    slot_shapes: Dict[str, Tuple[int, ...]] = {}
    for name in vector_metrics:
        slot_shapes[name] = (n_mc_runs, n_tau, n_causes)
    for name in matrix_metrics:
        slot_shapes[name] = (n_mc_runs, n_tau, n_topk, n_causes)
    store = {
        kind: {
            name: np.full(shape, np.nan, dtype=float)
            for name, shape in slot_shapes.items()
        }
        for kind in agg_kinds
    }

    # Count/total arrays start at zero, so groups without any populated bin
    # need no explicit handling.
    bins_used = np.zeros((n_mc_runs, n_tau), dtype=float)
    samples_total = np.zeros((n_mc_runs, n_tau), dtype=float)
    positives_total = np.zeros((n_mc_runs, n_tau, n_causes), dtype=float)

    for (mc, tau), members in by_run.items():
        mi = mc_pos[mc]
        ti = tau_pos[tau]

        # Only bins that actually contain samples take part in aggregation.
        populated = [b for b in members if int(b["n_samples"]) > 0]
        if not populated:
            continue

        weights = np.asarray(
            [int(b["n_samples"]) for b in populated], dtype=float
        )  # (B,)
        bins_used[mi, ti] = float(len(weights))
        samples_total[mi, ti] = float(weights.sum())

        positives = np.stack(
            [np.asarray(b["n_positives"], dtype=float) for b in populated], axis=0
        )  # (B, K)
        positives_total[mi, ti, :] = positives.sum(axis=0)

        for name in vector_metrics + matrix_metrics:
            stacked = np.stack(
                [np.asarray(b[name], dtype=float) for b in populated], axis=0
            )  # (B, K) or (B, P, K)
            store["macro"][name][mi, ti] = _nanmean_np(stacked, axis=0)
            # _weighted_mean_np pads the (B,) weights with trailing singleton
            # dims, so the same call serves both the 2-D and 3-D stacks.
            store["weighted"][name][mi, ti] = _weighted_mean_np(
                stacked, weights, axis=0
            )

    # Across-MC summary, emitted in cause-major long format.
    n_rows = n_causes * n_topk
    cause_long = np.repeat(cause_id, n_topk)
    topk_long = np.tile(topk_percents.astype(float), n_causes)
    n_mc_val = float(n_mc_runs)

    pieces: Dict[str, List[np.ndarray]] = {name: [] for name in out_columns}
    for ti, tau in enumerate(tau_sorted):
        # Totals do not depend on the aggregation kind.
        bins_mean = float(np.mean(bins_used[:, ti]))
        samples_mean = float(np.mean(samples_total[:, ti]))
        positives_mean = _nanmean_np(positives_total[:, ti, :], axis=0)  # (K,)

        for kind in agg_kinds:
            pieces["agg_type"].append(np.full(n_rows, kind, dtype=object))
            pieces["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float))
            pieces["topk_percent"].append(topk_long)
            pieces["cause_id"].append(cause_long)
            pieces["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float))
            pieces["n_bins_used_mean"].append(
                np.full(n_rows, bins_mean, dtype=float)
            )
            pieces["n_samples_total_mean"].append(
                np.full(n_rows, samples_mean, dtype=float)
            )
            pieces["n_positives_total_mean"].append(
                np.repeat(positives_mean, n_topk)
            )

            for name in vector_metrics:
                across_mc = store[kind][name][:, ti, :]  # (M, K)
                pieces[f"{name}_mean"].append(
                    np.repeat(_nanmean_np(across_mc, axis=0), n_topk)
                )
                pieces[f"{name}_std"].append(
                    np.repeat(_nanstd_np_ddof1(across_mc, axis=0), n_topk)
                )

            for name in matrix_metrics:
                across_mc = store[kind][name][:, ti, :, :]  # (M, P, K)
                # (P, K) -> (K, P) -> flat gives cause-major ordering.
                pieces[f"{name}_mean"].append(
                    _nanmean_np(across_mc, axis=0).T.reshape(-1)
                )
                pieces[f"{name}_std"].append(
                    _nanstd_np_ddof1(across_mc, axis=0).T.reshape(-1)
                )

    frame = pd.DataFrame(
        {name: np.concatenate(parts, axis=0) for name, parts in pieces.items()}
    )
    return frame.sort_values(
        ["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True
    )
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
"""Aggregate per-bin age evaluation results.
@@ -129,34 +487,79 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
)
)
# Kept for backward compatibility (e.g., if callers load a CSV and need to
# aggregate). Prefer `aggregate_metrics_columnar` during evaluation.
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
gb = df_by_bin.groupby(group_keys)
df = df_by_bin[df_by_bin["n_samples"] > 0].copy()
if len(df) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
try:
df_mc_macro = gb.apply(
lambda g: _bin_aggregate(g, weighted=False), include_groups=False
).reset_index()
except TypeError: # pandas<2.2 (no include_groups)
df_mc_macro = gb.apply(lambda g: _bin_aggregate(
g, weighted=False)).reset_index()
# Macro: mean over bins
df_mc_macro = (
df.groupby(group_keys, as_index=False)
.agg(
n_bins_used=("age_bin_id", "nunique"),
n_samples_total=("n_samples", "sum"),
n_positives_total=("n_positives", "sum"),
auc=("auc", "mean"),
auprc=("auprc", "mean"),
recall_at_K=("recall_at_K", "mean"),
precision_at_K=("precision_at_K", "mean"),
brier_score=("brier_score", "mean"),
)
)
df_mc_macro["agg_type"] = "macro"
try:
df_mc_weighted = gb.apply(
lambda g: _bin_aggregate(g, weighted=True), include_groups=False
).reset_index()
except TypeError: # pandas<2.2 (no include_groups)
df_mc_weighted = gb.apply(
lambda g: _bin_aggregate(g, weighted=True)).reset_index()
df_mc_weighted["agg_type"] = "weighted"
# Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric
w = df["n_samples"].astype(float)
df_w = df.copy()
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
m = df_w[col].astype(float)
ww = w.where(m.notna(), other=0.0)
df_w[f"__num_{col}"] = (m.fillna(0.0) * w)
df_w[f"__den_{col}"] = ww
df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
df_mc_w = df_w.groupby(group_keys, as_index=False).agg(
n_bins_used=("age_bin_id", "nunique"),
n_samples_total=("n_samples", "sum"),
n_positives_total=("n_positives", "sum"),
**{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
**{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
)
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
num = df_mc_w[f"__num_{col}"].astype(float)
den = df_mc_w[f"__den_{col}"].astype(float)
df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan"))
df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True)
df_mc_w["agg_type"] = "weighted"
df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True)
df_agg = (
df_mc_binagg.groupby(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
)
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False)
.agg(
n_mc=("mc_idx", "nunique"),
n_bins_used_mean=("n_bins_used", "mean"),
@@ -173,10 +576,7 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
brier_score_mean=("brier_score", "mean"),
brier_score_std=("brier_score", "std"),
)
.sort_values(
["agg_type", "horizon_tau", "topk_percent", "cause_id"],
ignore_index=True,
)
.sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True)
)
return df_agg
@@ -193,6 +593,7 @@ class EvalAgeConfig:
n_mc: int = 5
seed: int = 0
cause_ids: Optional[Sequence[int]] = None
store_per_cause: bool = True
@torch.inference_mode()
@@ -247,12 +648,17 @@ def evaluate_time_dependent_age_bins(
if cfg.cause_ids is None:
cause_ids = None
n_causes_eval = int(n_disease)
cause_id_vec = np.arange(n_causes_eval, dtype=int)
else:
cause_ids = torch.tensor(
list(cfg.cause_ids), dtype=torch.long, device=device)
n_causes_eval = int(cause_ids.numel())
cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int)
rows_by_bin: List[Dict[str, float | int]] = []
topk_percents_np = np.asarray(topk_percents, dtype=float)
# Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends.
blocks: List[Dict[str, Any]] = []
for mc_idx in range(int(cfg.n_mc)):
global_mc_idx = int(mc_offset) + int(mc_idx)
@@ -354,29 +760,27 @@ def evaluate_time_dependent_age_bins(
for h_idx, tau_y in enumerate(horizons_years):
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
if len(y_true_mc[h_idx][bin_idx]) == 0:
# No samples in this bin for this (mc, tau)
for cause_k in range(n_causes_eval):
cause_id = int(cause_k) if cause_ids is None else int(
cfg.cause_ids[cause_k])
for k_percent in topk_percents:
rows_by_bin.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
topk_percent=float(k_percent),
cause_id=cause_id,
n_samples=0,
n_positives=0,
auc=float("nan"),
auprc=float("nan"),
recall_at_K=float("nan"),
precision_at_K=float("nan"),
brier_score=float("nan"),
)
)
# No samples in this bin for this (mc, tau): store a single
# block with NaN metric vectors.
K = int(n_causes_eval)
P = int(topk_percents_np.size)
blocks.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
n_samples=0,
cause_id=cause_id_vec,
n_positives=np.zeros((K,), dtype=int),
auc=np.full((K,), np.nan, dtype=float),
auprc=np.full((K,), np.nan, dtype=float),
brier_score=np.full((K,), np.nan, dtype=float),
precision_at_K=np.full((P, K), np.nan, dtype=float),
recall_at_K=np.full((P, K), np.nan, dtype=float),
)
)
continue
yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
@@ -397,40 +801,52 @@ def evaluate_time_dependent_age_bins(
compute_ici=False,
)
# Move just the metric vectors to CPU once per (mc, tau, bin)
# for DataFrame construction.
auc = metrics.auc_per_cause.detach().cpu().numpy()
auprc = metrics.ap_per_cause.detach().cpu().numpy()
brier = metrics.brier_per_cause.detach().cpu().numpy()
n_pos = metrics.n_pos_per_cause.detach().cpu().numpy()
prec_at_k = metrics.precision_at_k.detach().cpu().numpy() # (P,K)
rec_at_k = metrics.recall_at_k.detach().cpu().numpy() # (P,K)
# Collect a single columnar block (vectors, not per-row dicts).
blocks.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
n_samples=int(n_samples),
cause_id=cause_id_vec,
n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int),
auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float),
auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float),
brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float),
precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float),
recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float),
)
)
for cause_k in range(n_causes_eval):
cause_id = int(cause_k) if cause_ids is None else int(
cfg.cause_ids[cause_k])
for p_idx, k_percent in enumerate(topk_percents):
rows_by_bin.append(
dict(
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
topk_percent=float(k_percent),
cause_id=cause_id,
n_samples=n_samples,
n_positives=int(n_pos[cause_k]),
auc=float(auc[cause_k]),
auprc=float(auprc[cause_k]),
recall_at_K=float(rec_at_k[p_idx, cause_k]),
precision_at_K=float(prec_at_k[p_idx, cause_k]),
brier_score=float(brier[cause_k]),
)
)
# Aggregation is computed from columnar blocks (fast, no pandas apply).
df_agg = aggregate_metrics_columnar(
blocks,
topk_percents=topk_percents_np,
cause_id=cause_id_vec,
)
df_by_bin = pd.DataFrame(rows_by_bin)
df_agg = aggregate_age_bin_results(df_by_bin)
if bool(cfg.store_per_cause):
df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np)
else:
df_by_bin = pd.DataFrame(
columns=[
"mc_idx",
"age_bin_id",
"age_bin_low",
"age_bin_high",
"horizon_tau",
"topk_percent",
"cause_id",
"n_samples",
"n_positives",
"auc",
"auprc",
"recall_at_K",
"precision_at_K",
"brier_score",
]
)
return df_by_bin, df_agg