Add function to drop zero-positive rows and update CSV export logic in age-bin evaluation

This commit is contained in:
2026-01-16 17:51:00 +08:00
parent 4068310a12
commit a637beb220
2 changed files with 521 additions and 83 deletions

View File

@@ -177,6 +177,7 @@ def _worker_eval_mcs_on_gpu(
df_all = pd.concat(frames, ignore_index=True) if len( df_all = pd.concat(frames, ignore_index=True) if len(
frames) else pd.DataFrame() frames) else pd.DataFrame()
df_all = _drop_zero_positives_rows(df_all, "n_positives")
df_all.to_csv(out_path, index=False) df_all.to_csv(out_path, index=False)
queue.put({"ok": True, "out_path": out_path}) queue.put({"ok": True, "out_path": out_path})
except Exception as e: except Exception as e:
@@ -211,6 +212,20 @@ def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lamb
raise ValueError(f"Unsupported loss_type: {loss_type}") raise ValueError(f"Unsupported loss_type: {loss_type}")
def _drop_zero_positives_rows(df: pd.DataFrame, positive_col: str) -> pd.DataFrame:
"""Drop rows where the provided positives column is <= 0.
Intended to reduce CSV size by omitting (cause, horizon, bin) rows that have
no positives, which otherwise yield undefined/NaN metrics.
"""
if df is None or len(df) == 0:
return df
if positive_col not in df.columns:
return df
pos = pd.to_numeric(df[positive_col], errors="coerce")
return df[pos > 0].copy()
def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict): def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
if model_type == "delphi_fork": if model_type == "delphi_fork":
return DelphiFork( return DelphiFork(
@@ -404,8 +419,10 @@ def main() -> None:
device=device, device=device,
) )
df_by_bin.to_csv(out_bin, index=False) df_by_bin_csv = _drop_zero_positives_rows(df_by_bin, "n_positives")
df_agg.to_csv(out_agg, index=False) df_agg_csv = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
df_by_bin_csv.to_csv(out_bin, index=False)
df_agg_csv.to_csv(out_agg, index=False)
print(f"Wrote: {out_bin}") print(f"Wrote: {out_bin}")
print(f"Wrote: {out_agg}") print(f"Wrote: {out_agg}")
return return
@@ -464,8 +481,13 @@ def main() -> None:
frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)] frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
df_by_bin = pd.concat(frames, ignore_index=True) if len( df_by_bin = pd.concat(frames, ignore_index=True) if len(
frames) else pd.DataFrame() frames) else pd.DataFrame()
# Ensure we don't keep zero-positive rows even if a temp file was produced
# by an older version of the worker.
df_by_bin = _drop_zero_positives_rows(df_by_bin, "n_positives")
df_agg = aggregate_age_bin_results(df_by_bin) df_agg = aggregate_age_bin_results(df_by_bin)
df_agg = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
df_by_bin.to_csv(out_bin, index=False) df_by_bin.to_csv(out_bin, index=False)
df_agg.to_csv(out_agg, index=False) df_agg.to_csv(out_agg, index=False)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -23,6 +23,364 @@ from utils import (
from torch_metrics import compute_binary_metrics_torch from torch_metrics import compute_binary_metrics_torch
def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray:
with np.errstate(invalid="ignore"):
return np.nanmean(x, axis=axis)
def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray:
"""NaN-aware sample std (ddof=1), matching pandas std() semantics."""
x = np.asarray(x, dtype=float)
mask = np.isfinite(x)
cnt = mask.sum(axis=axis)
# mean over finite entries
x0 = np.where(mask, x, 0.0)
mean = x0.sum(axis=axis) / np.maximum(cnt, 1)
# sum of squared deviations over finite entries
dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0)
ss = dev2.sum(axis=axis)
denom = cnt - 1
out = np.sqrt(ss / np.maximum(denom, 1))
out = np.where(denom > 0, out, np.nan)
return out
def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray:
"""NaN-aware weighted mean.
Only bins with finite x contribute to both numerator and denominator.
If denom==0 -> NaN.
"""
x = np.asarray(x, dtype=float)
w = np.asarray(w, dtype=float)
if axis != 0:
raise ValueError("_weighted_mean_np currently supports axis=0 only")
# Broadcast weights along trailing dims of x.
while w.ndim < x.ndim:
w = w[..., None]
w = np.broadcast_to(w, x.shape)
mask = np.isfinite(x)
num = np.where(mask, x * w, 0.0).sum(axis=0)
denom = np.where(mask, w, 0.0).sum(axis=0)
return np.where(denom > 0.0, num / denom, np.nan)
def _blocks_to_df_by_bin(
blocks: List[Dict[str, Any]],
*,
topk_percents: np.ndarray,
) -> pd.DataFrame:
"""Convert per-block column vectors into the long-format per-bin DataFrame.
This does a single vectorized reshape per block (cause-major ordering), and
concatenates columns once at the end.
"""
if len(blocks) == 0:
return pd.DataFrame(
columns=[
"mc_idx",
"age_bin_id",
"age_bin_low",
"age_bin_high",
"horizon_tau",
"topk_percent",
"cause_id",
"n_samples",
"n_positives",
"auc",
"auprc",
"recall_at_K",
"precision_at_K",
"brier_score",
]
)
P = int(topk_percents.size)
cols: Dict[str, List[np.ndarray]] = {
"mc_idx": [],
"age_bin_id": [],
"age_bin_low": [],
"age_bin_high": [],
"horizon_tau": [],
"topk_percent": [],
"cause_id": [],
"n_samples": [],
"n_positives": [],
"auc": [],
"auprc": [],
"recall_at_K": [],
"precision_at_K": [],
"brier_score": [],
}
for blk in blocks:
cause_id = np.asarray(blk["cause_id"], dtype=int)
K = int(cause_id.size)
n_rows = K * P
cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int))
cols["age_bin_id"].append(
np.full(n_rows, int(blk["age_bin_id"]), dtype=int))
cols["age_bin_low"].append(
np.full(n_rows, float(blk["age_bin_low"]), dtype=float))
cols["age_bin_high"].append(
np.full(n_rows, float(blk["age_bin_high"]), dtype=float))
cols["horizon_tau"].append(
np.full(n_rows, float(blk["horizon_tau"]), dtype=float))
cols["cause_id"].append(np.repeat(cause_id, P))
cols["topk_percent"].append(np.tile(topk_percents.astype(float), K))
cols["n_samples"].append(
np.full(n_rows, int(blk["n_samples"]), dtype=int))
n_pos = np.asarray(blk["n_positives"], dtype=int)
cols["n_positives"].append(np.repeat(n_pos, P))
auc = np.asarray(blk["auc"], dtype=float)
auprc = np.asarray(blk["auprc"], dtype=float)
brier = np.asarray(blk["brier_score"], dtype=float)
cols["auc"].append(np.repeat(auc, P))
cols["auprc"].append(np.repeat(auprc, P))
cols["brier_score"].append(np.repeat(brier, P))
# precision/recall are stored as (P,K); we want cause-major rows, i.e.
# (K,P) then flatten.
prec = np.asarray(blk["precision_at_K"], dtype=float)
rec = np.asarray(blk["recall_at_K"], dtype=float)
if prec.shape != (P, K) or rec.shape != (P, K):
raise ValueError(
f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}"
)
cols["precision_at_K"].append(prec.T.reshape(-1))
cols["recall_at_K"].append(rec.T.reshape(-1))
out = {k: np.concatenate(v, axis=0) for k, v in cols.items()}
return pd.DataFrame(out)
def aggregate_metrics_columnar(
blocks: List[Dict[str, Any]],
*,
topk_percents: np.ndarray,
cause_id: np.ndarray,
) -> pd.DataFrame:
"""Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std).
This is a vectorized, columnar replacement for the old pandas groupby/apply.
Semantics match the previous implementation:
- bins with n_samples==0 are excluded from bin-aggregation
- macro: unweighted mean over bins (NaN-aware)
- weighted: weighted mean over bins using weights=n_samples (NaN-aware)
- across MC: mean/std (ddof=1), NaN-aware
"""
if len(blocks) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
P = int(topk_percents.size)
cause_id = np.asarray(cause_id, dtype=int)
K = int(cause_id.size)
# Group blocks by (mc_idx, horizon_tau)
keys: List[Tuple[int, float]] = []
grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {}
for blk in blocks:
key = (int(blk["mc_idx"]), float(blk["horizon_tau"]))
if key not in grouped:
grouped[key] = []
keys.append(key)
grouped[key].append(blk)
mc_vals = sorted({k[0] for k in keys})
tau_vals = sorted({k[1] for k in keys})
M = len(mc_vals)
T = len(tau_vals)
mc_index = {mc: i for i, mc in enumerate(mc_vals)}
tau_index = {tau: i for i, tau in enumerate(tau_vals)}
# Per (agg_type, mc, tau): store arrays
# metrics: (M,T,K) and (M,T,P,K)
auc_macro = np.full((M, T, K), np.nan, dtype=float)
auc_weighted = np.full((M, T, K), np.nan, dtype=float)
ap_macro = np.full((M, T, K), np.nan, dtype=float)
ap_weighted = np.full((M, T, K), np.nan, dtype=float)
brier_macro = np.full((M, T, K), np.nan, dtype=float)
brier_weighted = np.full((M, T, K), np.nan, dtype=float)
prec_macro = np.full((M, T, P, K), np.nan, dtype=float)
prec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
rec_macro = np.full((M, T, P, K), np.nan, dtype=float)
rec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
n_bins_used = np.zeros((M, T), dtype=float)
n_samples_total = np.zeros((M, T), dtype=float)
n_pos_total = np.zeros((M, T, K), dtype=float)
for (mc, tau), blks in grouped.items():
mi = mc_index[mc]
ti = tau_index[tau]
# keep only bins with n_samples>0
blks_nz = [b for b in blks if int(b["n_samples"]) > 0]
if len(blks_nz) == 0:
n_bins_used[mi, ti] = 0.0
n_samples_total[mi, ti] = 0.0
n_pos_total[mi, ti, :] = 0.0
continue
w = np.asarray([int(b["n_samples"])
for b in blks_nz], dtype=float) # (B,)
n_bins_used[mi, ti] = float(len(w))
n_samples_total[mi, ti] = float(w.sum())
npos = np.stack([np.asarray(b["n_positives"], dtype=float)
for b in blks_nz], axis=0) # (B,K)
n_pos_total[mi, ti, :] = npos.sum(axis=0)
auc_b = np.stack([np.asarray(b["auc"], dtype=float)
for b in blks_nz], axis=0) # (B,K)
ap_b = np.stack([np.asarray(b["auprc"], dtype=float)
for b in blks_nz], axis=0)
brier_b = np.stack([np.asarray(b["brier_score"], dtype=float)
for b in blks_nz], axis=0)
auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0)
ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0)
brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0)
auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0)
ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0)
brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0)
prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float)
for b in blks_nz], axis=0) # (B,P,K)
rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float)
for b in blks_nz], axis=0)
# macro mean over bins
prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0)
rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0)
# weighted mean over bins (weights along bin axis)
w3 = w.reshape(-1, 1, 1)
prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0)
rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0)
# Across-MC aggregation (mean/std), then emit long-format df keyed by
# (agg_type, horizon_tau, topk_percent, cause_id)
rows: Dict[str, List[np.ndarray]] = {
"agg_type": [],
"horizon_tau": [],
"topk_percent": [],
"cause_id": [],
"n_mc": [],
"n_bins_used_mean": [],
"n_samples_total_mean": [],
"n_positives_total_mean": [],
"auc_mean": [],
"auc_std": [],
"auprc_mean": [],
"auprc_std": [],
"recall_at_K_mean": [],
"recall_at_K_std": [],
"precision_at_K_mean": [],
"precision_at_K_std": [],
"brier_score_mean": [],
"brier_score_std": [],
}
cause_long = np.repeat(cause_id, P)
topk_long = np.tile(topk_percents.astype(float), K)
n_mc_val = float(M)
for ti, tau in enumerate(tau_vals):
# scalar totals (repeat across causes/topk)
n_bins_mean = float(
np.mean(n_bins_used[:, ti])) if M > 0 else float("nan")
n_samp_mean = float(
np.mean(n_samples_total[:, ti])) if M > 0 else float("nan")
n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0) # (K,)
for agg_type in ("macro", "weighted"):
if agg_type == "macro":
auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0)
auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0)
ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0)
ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0)
brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0)
brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0)
prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0) # (P,K)
prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0)
rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0)
rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0)
else:
auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0)
auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0)
ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0)
ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0)
brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0)
brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0)
prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0)
prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0)
rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0)
rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0)
n_rows = K * P
rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object))
rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float))
rows["topk_percent"].append(topk_long)
rows["cause_id"].append(cause_long)
rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float))
rows["n_bins_used_mean"].append(
np.full(n_rows, n_bins_mean, dtype=float))
rows["n_samples_total_mean"].append(
np.full(n_rows, n_samp_mean, dtype=float))
rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P))
rows["auc_mean"].append(np.repeat(auc_m, P))
rows["auc_std"].append(np.repeat(auc_s, P))
rows["auprc_mean"].append(np.repeat(ap_m, P))
rows["auprc_std"].append(np.repeat(ap_s, P))
rows["brier_score_mean"].append(np.repeat(brier_m, P))
rows["brier_score_std"].append(np.repeat(brier_s, P))
rows["precision_at_K_mean"].append(prec_m.T.reshape(-1))
rows["precision_at_K_std"].append(prec_s.T.reshape(-1))
rows["recall_at_K_mean"].append(rec_m.T.reshape(-1))
rows["recall_at_K_std"].append(rec_s.T.reshape(-1))
out = {k: np.concatenate(v, axis=0) for k, v in rows.items()}
df = pd.DataFrame(out)
return df.sort_values(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True
)
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
"""Aggregate per-bin age evaluation results. """Aggregate per-bin age evaluation results.
@@ -129,34 +487,79 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
) )
) )
# Kept for backward compatibility (e.g., if callers load a CSV and need to
# aggregate). Prefer `aggregate_metrics_columnar` during evaluation.
group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"] group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
gb = df_by_bin.groupby(group_keys) df = df_by_bin[df_by_bin["n_samples"] > 0].copy()
if len(df) == 0:
return pd.DataFrame(
columns=[
"agg_type",
"horizon_tau",
"topk_percent",
"cause_id",
"n_mc",
"n_bins_used_mean",
"n_samples_total_mean",
"n_positives_total_mean",
"auc_mean",
"auc_std",
"auprc_mean",
"auprc_std",
"recall_at_K_mean",
"recall_at_K_std",
"precision_at_K_mean",
"precision_at_K_std",
"brier_score_mean",
"brier_score_std",
]
)
try: # Macro: mean over bins
df_mc_macro = gb.apply( df_mc_macro = (
lambda g: _bin_aggregate(g, weighted=False), include_groups=False df.groupby(group_keys, as_index=False)
).reset_index() .agg(
except TypeError: # pandas<2.2 (no include_groups) n_bins_used=("age_bin_id", "nunique"),
df_mc_macro = gb.apply(lambda g: _bin_aggregate( n_samples_total=("n_samples", "sum"),
g, weighted=False)).reset_index() n_positives_total=("n_positives", "sum"),
auc=("auc", "mean"),
auprc=("auprc", "mean"),
recall_at_K=("recall_at_K", "mean"),
precision_at_K=("precision_at_K", "mean"),
brier_score=("brier_score", "mean"),
)
)
df_mc_macro["agg_type"] = "macro" df_mc_macro["agg_type"] = "macro"
try: # Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric
df_mc_weighted = gb.apply( w = df["n_samples"].astype(float)
lambda g: _bin_aggregate(g, weighted=True), include_groups=False df_w = df.copy()
).reset_index() for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
except TypeError: # pandas<2.2 (no include_groups) m = df_w[col].astype(float)
df_mc_weighted = gb.apply( ww = w.where(m.notna(), other=0.0)
lambda g: _bin_aggregate(g, weighted=True)).reset_index() df_w[f"__num_{col}"] = (m.fillna(0.0) * w)
df_mc_weighted["agg_type"] = "weighted" df_w[f"__den_{col}"] = ww
df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True) df_mc_w = df_w.groupby(group_keys, as_index=False).agg(
n_bins_used=("age_bin_id", "nunique"),
n_samples_total=("n_samples", "sum"),
n_positives_total=("n_positives", "sum"),
**{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
**{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
)
for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
num = df_mc_w[f"__num_{col}"].astype(float)
den = df_mc_w[f"__den_{col}"].astype(float)
df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan"))
df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True)
df_mc_w["agg_type"] = "weighted"
df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True)
df_agg = ( df_agg = (
df_mc_binagg.groupby( df_mc_binagg.groupby(
["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False)
)
.agg( .agg(
n_mc=("mc_idx", "nunique"), n_mc=("mc_idx", "nunique"),
n_bins_used_mean=("n_bins_used", "mean"), n_bins_used_mean=("n_bins_used", "mean"),
@@ -173,10 +576,7 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
brier_score_mean=("brier_score", "mean"), brier_score_mean=("brier_score", "mean"),
brier_score_std=("brier_score", "std"), brier_score_std=("brier_score", "std"),
) )
.sort_values( .sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True)
["agg_type", "horizon_tau", "topk_percent", "cause_id"],
ignore_index=True,
)
) )
return df_agg return df_agg
@@ -193,6 +593,7 @@ class EvalAgeConfig:
n_mc: int = 5 n_mc: int = 5
seed: int = 0 seed: int = 0
cause_ids: Optional[Sequence[int]] = None cause_ids: Optional[Sequence[int]] = None
store_per_cause: bool = True
@torch.inference_mode() @torch.inference_mode()
@@ -247,12 +648,17 @@ def evaluate_time_dependent_age_bins(
if cfg.cause_ids is None: if cfg.cause_ids is None:
cause_ids = None cause_ids = None
n_causes_eval = int(n_disease) n_causes_eval = int(n_disease)
cause_id_vec = np.arange(n_causes_eval, dtype=int)
else: else:
cause_ids = torch.tensor( cause_ids = torch.tensor(
list(cfg.cause_ids), dtype=torch.long, device=device) list(cfg.cause_ids), dtype=torch.long, device=device)
n_causes_eval = int(cause_ids.numel()) n_causes_eval = int(cause_ids.numel())
cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int)
rows_by_bin: List[Dict[str, float | int]] = [] topk_percents_np = np.asarray(topk_percents, dtype=float)
# Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends.
blocks: List[Dict[str, Any]] = []
for mc_idx in range(int(cfg.n_mc)): for mc_idx in range(int(cfg.n_mc)):
global_mc_idx = int(mc_offset) + int(mc_idx) global_mc_idx = int(mc_offset) + int(mc_idx)
@@ -354,29 +760,27 @@ def evaluate_time_dependent_age_bins(
for h_idx, tau_y in enumerate(horizons_years): for h_idx, tau_y in enumerate(horizons_years):
for bin_idx, (a_lo, a_hi) in enumerate(age_bins): for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
if len(y_true_mc[h_idx][bin_idx]) == 0: if len(y_true_mc[h_idx][bin_idx]) == 0:
# No samples in this bin for this (mc, tau) # No samples in this bin for this (mc, tau): store a single
for cause_k in range(n_causes_eval): # block with NaN metric vectors.
cause_id = int(cause_k) if cause_ids is None else int( K = int(n_causes_eval)
cfg.cause_ids[cause_k]) P = int(topk_percents_np.size)
for k_percent in topk_percents: blocks.append(
rows_by_bin.append( dict(
dict( mc_idx=global_mc_idx,
mc_idx=global_mc_idx, age_bin_id=bin_idx,
age_bin_id=bin_idx, age_bin_low=float(a_lo),
age_bin_low=float(a_lo), age_bin_high=float(a_hi),
age_bin_high=float(a_hi), horizon_tau=float(tau_y),
horizon_tau=float(tau_y), n_samples=0,
topk_percent=float(k_percent), cause_id=cause_id_vec,
cause_id=cause_id, n_positives=np.zeros((K,), dtype=int),
n_samples=0, auc=np.full((K,), np.nan, dtype=float),
n_positives=0, auprc=np.full((K,), np.nan, dtype=float),
auc=float("nan"), brier_score=np.full((K,), np.nan, dtype=float),
auprc=float("nan"), precision_at_K=np.full((P, K), np.nan, dtype=float),
recall_at_K=float("nan"), recall_at_K=np.full((P, K), np.nan, dtype=float),
precision_at_K=float("nan"), )
brier_score=float("nan"), )
)
)
continue continue
yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0) yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
@@ -397,40 +801,52 @@ def evaluate_time_dependent_age_bins(
compute_ici=False, compute_ici=False,
) )
# Move just the metric vectors to CPU once per (mc, tau, bin) # Collect a single columnar block (vectors, not per-row dicts).
# for DataFrame construction. blocks.append(
auc = metrics.auc_per_cause.detach().cpu().numpy() dict(
auprc = metrics.ap_per_cause.detach().cpu().numpy() mc_idx=global_mc_idx,
brier = metrics.brier_per_cause.detach().cpu().numpy() age_bin_id=bin_idx,
n_pos = metrics.n_pos_per_cause.detach().cpu().numpy() age_bin_low=float(a_lo),
prec_at_k = metrics.precision_at_k.detach().cpu().numpy() # (P,K) age_bin_high=float(a_hi),
rec_at_k = metrics.recall_at_k.detach().cpu().numpy() # (P,K) horizon_tau=float(tau_y),
n_samples=int(n_samples),
cause_id=cause_id_vec,
n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int),
auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float),
auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float),
brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float),
precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float),
recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float),
)
)
for cause_k in range(n_causes_eval): # Aggregation is computed from columnar blocks (fast, no pandas apply).
cause_id = int(cause_k) if cause_ids is None else int( df_agg = aggregate_metrics_columnar(
cfg.cause_ids[cause_k]) blocks,
for p_idx, k_percent in enumerate(topk_percents): topk_percents=topk_percents_np,
rows_by_bin.append( cause_id=cause_id_vec,
dict( )
mc_idx=global_mc_idx,
age_bin_id=bin_idx,
age_bin_low=float(a_lo),
age_bin_high=float(a_hi),
horizon_tau=float(tau_y),
topk_percent=float(k_percent),
cause_id=cause_id,
n_samples=n_samples,
n_positives=int(n_pos[cause_k]),
auc=float(auc[cause_k]),
auprc=float(auprc[cause_k]),
recall_at_K=float(rec_at_k[p_idx, cause_k]),
precision_at_K=float(prec_at_k[p_idx, cause_k]),
brier_score=float(brier[cause_k]),
)
)
df_by_bin = pd.DataFrame(rows_by_bin) if bool(cfg.store_per_cause):
df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np)
df_agg = aggregate_age_bin_results(df_by_bin) else:
df_by_bin = pd.DataFrame(
columns=[
"mc_idx",
"age_bin_id",
"age_bin_low",
"age_bin_high",
"horizon_tau",
"topk_percent",
"cause_id",
"n_samples",
"n_positives",
"auc",
"auprc",
"recall_at_K",
"precision_at_K",
"brier_score",
]
)
return df_by_bin, df_agg return df_by_bin, df_agg