From a637beb220ffff0dbaae8d2479ef9ab0c5077c49 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 16 Jan 2026 17:51:00 +0800 Subject: [PATCH] Add function to drop zero-positive rows and update CSV export logic in age-bin evaluation --- evaluate_age.py | 26 +- evaluation_age_time_dependent.py | 578 ++++++++++++++++++++++++++----- 2 files changed, 521 insertions(+), 83 deletions(-) diff --git a/evaluate_age.py b/evaluate_age.py index 0dea88b..2de473a 100644 --- a/evaluate_age.py +++ b/evaluate_age.py @@ -177,6 +177,7 @@ def _worker_eval_mcs_on_gpu( df_all = pd.concat(frames, ignore_index=True) if len( frames) else pd.DataFrame() + df_all = _drop_zero_positives_rows(df_all, "n_positives") df_all.to_csv(out_path, index=False) queue.put({"ok": True, "out_path": out_path}) except Exception as e: @@ -211,6 +212,20 @@ def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lamb raise ValueError(f"Unsupported loss_type: {loss_type}") +def _drop_zero_positives_rows(df: pd.DataFrame, positive_col: str) -> pd.DataFrame: + """Drop rows where the provided positives column is <= 0. + + Intended to reduce CSV size by omitting (cause, horizon, bin) rows that have + no positives, which otherwise yield undefined/NaN metrics. + """ + if df is None or len(df) == 0: + return df + if positive_col not in df.columns: + return df + pos = pd.to_numeric(df[positive_col], errors="coerce") + return df[pos > 0].copy() + + def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict): if model_type == "delphi_fork": return DelphiFork( @@ -404,8 +419,10 @@ def main() -> None: device=device, ) - df_by_bin.to_csv(out_bin, index=False) - df_agg.to_csv(out_agg, index=False) + df_by_bin_csv = _drop_zero_positives_rows(df_by_bin, "n_positives") + df_agg_csv = _drop_zero_positives_rows(df_agg, "n_positives_total_mean") + df_by_bin_csv.to_csv(out_bin, index=False) + df_agg_csv.to_csv(out_agg, index=False) print(f"Wrote: {out_bin}") print(f"Wrote: {out_agg}") return @@ -464,8 +481,13 @@ def main() -> None: frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)] df_by_bin = pd.concat(frames, ignore_index=True) if len( frames) else pd.DataFrame() + + # Ensure we don't keep zero-positive rows even if a temp file was produced + # by an older version of the worker. + df_by_bin = _drop_zero_positives_rows(df_by_bin, "n_positives") df_agg = aggregate_age_bin_results(df_by_bin) + df_agg = _drop_zero_positives_rows(df_agg, "n_positives_total_mean") df_by_bin.to_csv(out_bin, index=False) df_agg.to_csv(out_agg, index=False) diff --git a/evaluation_age_time_dependent.py b/evaluation_age_time_dependent.py index 8dafb7f..fbc26c2 100644 --- a/evaluation_age_time_dependent.py +++ b/evaluation_age_time_dependent.py @@ -2,7 +2,7 @@ from __future__ import annotations import math from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np import pandas as pd @@ -23,6 +23,364 @@ from utils import ( from torch_metrics import compute_binary_metrics_torch +def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray: + with np.errstate(invalid="ignore"): + return np.nanmean(x, axis=axis) + + +def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray: + """NaN-aware sample std (ddof=1), matching pandas std() semantics.""" + x = np.asarray(x, dtype=float) + mask = np.isfinite(x) + cnt = mask.sum(axis=axis) + # mean over finite entries + x0 = np.where(mask, x, 0.0) + mean = x0.sum(axis=axis) / np.maximum(cnt, 1) + # sum of squared deviations over finite entries + dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0) + ss = dev2.sum(axis=axis) + denom = cnt - 1 + out = np.sqrt(ss / np.maximum(denom, 1)) + out = np.where(denom > 0, out, np.nan) + return out + + +def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray: + """NaN-aware weighted mean. + + Only bins with finite x contribute to both numerator and denominator. + If denom==0 -> NaN. + """ + x = np.asarray(x, dtype=float) + w = np.asarray(w, dtype=float) + + if axis != 0: + raise ValueError("_weighted_mean_np currently supports axis=0 only") + + # Broadcast weights along trailing dims of x. + while w.ndim < x.ndim: + w = w[..., None] + w = np.broadcast_to(w, x.shape) + + mask = np.isfinite(x) + num = np.where(mask, x * w, 0.0).sum(axis=0) + denom = np.where(mask, w, 0.0).sum(axis=0) + return np.where(denom > 0.0, num / denom, np.nan) + + +def _blocks_to_df_by_bin( + blocks: List[Dict[str, Any]], + *, + topk_percents: np.ndarray, +) -> pd.DataFrame: + """Convert per-block column vectors into the long-format per-bin DataFrame. + + This does a single vectorized reshape per block (cause-major ordering), and + concatenates columns once at the end. + """ + if len(blocks) == 0: + return pd.DataFrame( + columns=[ + "mc_idx", + "age_bin_id", + "age_bin_low", + "age_bin_high", + "horizon_tau", + "topk_percent", + "cause_id", + "n_samples", + "n_positives", + "auc", + "auprc", + "recall_at_K", + "precision_at_K", + "brier_score", + ] + ) + + P = int(topk_percents.size) + + cols: Dict[str, List[np.ndarray]] = { + "mc_idx": [], + "age_bin_id": [], + "age_bin_low": [], + "age_bin_high": [], + "horizon_tau": [], + "topk_percent": [], + "cause_id": [], + "n_samples": [], + "n_positives": [], + "auc": [], + "auprc": [], + "recall_at_K": [], + "precision_at_K": [], + "brier_score": [], + } + + for blk in blocks: + cause_id = np.asarray(blk["cause_id"], dtype=int) + K = int(cause_id.size) + n_rows = K * P + + cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int)) + cols["age_bin_id"].append( + np.full(n_rows, int(blk["age_bin_id"]), dtype=int)) + cols["age_bin_low"].append( + np.full(n_rows, float(blk["age_bin_low"]), dtype=float)) + cols["age_bin_high"].append( + np.full(n_rows, float(blk["age_bin_high"]), dtype=float)) + cols["horizon_tau"].append( + np.full(n_rows, float(blk["horizon_tau"]), dtype=float)) + + cols["cause_id"].append(np.repeat(cause_id, P)) + cols["topk_percent"].append(np.tile(topk_percents.astype(float), K)) + cols["n_samples"].append( + np.full(n_rows, int(blk["n_samples"]), dtype=int)) + + n_pos = np.asarray(blk["n_positives"], dtype=int) + cols["n_positives"].append(np.repeat(n_pos, P)) + + auc = np.asarray(blk["auc"], dtype=float) + auprc = np.asarray(blk["auprc"], dtype=float) + brier = np.asarray(blk["brier_score"], dtype=float) + cols["auc"].append(np.repeat(auc, P)) + cols["auprc"].append(np.repeat(auprc, P)) + cols["brier_score"].append(np.repeat(brier, P)) + + # precision/recall are stored as (P,K); we want cause-major rows, i.e. + # (K,P) then flatten. + prec = np.asarray(blk["precision_at_K"], dtype=float) + rec = np.asarray(blk["recall_at_K"], dtype=float) + if prec.shape != (P, K) or rec.shape != (P, K): + raise ValueError( + f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}" + ) + cols["precision_at_K"].append(prec.T.reshape(-1)) + cols["recall_at_K"].append(rec.T.reshape(-1)) + + out = {k: np.concatenate(v, axis=0) for k, v in cols.items()} + return pd.DataFrame(out) + + +def aggregate_metrics_columnar( + blocks: List[Dict[str, Any]], + *, + topk_percents: np.ndarray, + cause_id: np.ndarray, +) -> pd.DataFrame: + """Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std). + + This is a vectorized, columnar replacement for the old pandas groupby/apply. + Semantics match the previous implementation: + - bins with n_samples==0 are excluded from bin-aggregation + - macro: unweighted mean over bins (NaN-aware) + - weighted: weighted mean over bins using weights=n_samples (NaN-aware) + - across MC: mean/std (ddof=1), NaN-aware + """ + if len(blocks) == 0: + return pd.DataFrame( + columns=[ + "agg_type", + "horizon_tau", + "topk_percent", + "cause_id", + "n_mc", + "n_bins_used_mean", + "n_samples_total_mean", + "n_positives_total_mean", + "auc_mean", + "auc_std", + "auprc_mean", + "auprc_std", + "recall_at_K_mean", + "recall_at_K_std", + "precision_at_K_mean", + "precision_at_K_std", + "brier_score_mean", + "brier_score_std", + ] + ) + + P = int(topk_percents.size) + cause_id = np.asarray(cause_id, dtype=int) + K = int(cause_id.size) + + # Group blocks by (mc_idx, horizon_tau) + keys: List[Tuple[int, float]] = [] + grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {} + for blk in blocks: + key = (int(blk["mc_idx"]), float(blk["horizon_tau"])) + if key not in grouped: + grouped[key] = [] + keys.append(key) + grouped[key].append(blk) + + mc_vals = sorted({k[0] for k in keys}) + tau_vals = sorted({k[1] for k in keys}) + M = len(mc_vals) + T = len(tau_vals) + + mc_index = {mc: i for i, mc in enumerate(mc_vals)} + tau_index = {tau: i for i, tau in enumerate(tau_vals)} + + # Per (agg_type, mc, tau): store arrays + # metrics: (M,T,K) and (M,T,P,K) + auc_macro = np.full((M, T, K), np.nan, dtype=float) + auc_weighted = np.full((M, T, K), np.nan, dtype=float) + ap_macro = np.full((M, T, K), np.nan, dtype=float) + ap_weighted = np.full((M, T, K), np.nan, dtype=float) + brier_macro = np.full((M, T, K), np.nan, dtype=float) + brier_weighted = np.full((M, T, K), np.nan, dtype=float) + + prec_macro = np.full((M, T, P, K), np.nan, dtype=float) + prec_weighted = np.full((M, T, P, K), np.nan, dtype=float) + rec_macro = np.full((M, T, P, K), np.nan, dtype=float) + rec_weighted = np.full((M, T, P, K), np.nan, dtype=float) + + n_bins_used = np.zeros((M, T), dtype=float) + n_samples_total = np.zeros((M, T), dtype=float) + n_pos_total = np.zeros((M, T, K), dtype=float) + + for (mc, tau), blks in grouped.items(): + mi = mc_index[mc] + ti = tau_index[tau] + + # keep only bins with n_samples>0 + blks_nz = [b for b in blks if int(b["n_samples"]) > 0] + if len(blks_nz) == 0: + n_bins_used[mi, ti] = 0.0 + n_samples_total[mi, ti] = 0.0 + n_pos_total[mi, ti, :] = 0.0 + continue + + w = np.asarray([int(b["n_samples"]) + for b in blks_nz], dtype=float) # (B,) + n_bins_used[mi, ti] = float(len(w)) + n_samples_total[mi, ti] = float(w.sum()) + + npos = np.stack([np.asarray(b["n_positives"], dtype=float) + for b in blks_nz], axis=0) # (B,K) + n_pos_total[mi, ti, :] = npos.sum(axis=0) + + auc_b = np.stack([np.asarray(b["auc"], dtype=float) + for b in blks_nz], axis=0) # (B,K) + ap_b = np.stack([np.asarray(b["auprc"], dtype=float) + for b in blks_nz], axis=0) + brier_b = np.stack([np.asarray(b["brier_score"], dtype=float) + for b in blks_nz], axis=0) + + auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0) + ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0) + brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0) + + auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0) + ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0) + brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0) + + prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float) + for b in blks_nz], axis=0) # (B,P,K) + rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float) + for b in blks_nz], axis=0) + + # macro mean over bins + prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0) + rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0) + + # weighted mean over bins (weights along bin axis) + w3 = w.reshape(-1, 1, 1) + prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0) + rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0) + + # Across-MC aggregation (mean/std), then emit long-format df keyed by + # (agg_type, horizon_tau, topk_percent, cause_id) + rows: Dict[str, List[np.ndarray]] = { + "agg_type": [], + "horizon_tau": [], + "topk_percent": [], + "cause_id": [], + "n_mc": [], + "n_bins_used_mean": [], + "n_samples_total_mean": [], + "n_positives_total_mean": [], + "auc_mean": [], + "auc_std": [], + "auprc_mean": [], + "auprc_std": [], + "recall_at_K_mean": [], + "recall_at_K_std": [], + "precision_at_K_mean": [], + "precision_at_K_std": [], + "brier_score_mean": [], + "brier_score_std": [], + } + + cause_long = np.repeat(cause_id, P) + topk_long = np.tile(topk_percents.astype(float), K) + n_mc_val = float(M) + + for ti, tau in enumerate(tau_vals): + # scalar totals (repeat across causes/topk) + n_bins_mean = float( + np.mean(n_bins_used[:, ti])) if M > 0 else float("nan") + n_samp_mean = float( + np.mean(n_samples_total[:, ti])) if M > 0 else float("nan") + n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0) # (K,) + + for agg_type in ("macro", "weighted"): + if agg_type == "macro": + auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0) + auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0) + ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0) + ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0) + brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0) + brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0) + prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0) # (P,K) + prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0) + rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0) + rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0) + else: + auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0) + auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0) + ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0) + ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0) + brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0) + brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0) + prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0) + prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0) + rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0) + rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0) + + n_rows = K * P + rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object)) + rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float)) + rows["topk_percent"].append(topk_long) + rows["cause_id"].append(cause_long) + rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float)) + rows["n_bins_used_mean"].append( + np.full(n_rows, n_bins_mean, dtype=float)) + rows["n_samples_total_mean"].append( + np.full(n_rows, n_samp_mean, dtype=float)) + rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P)) + + rows["auc_mean"].append(np.repeat(auc_m, P)) + rows["auc_std"].append(np.repeat(auc_s, P)) + rows["auprc_mean"].append(np.repeat(ap_m, P)) + rows["auprc_std"].append(np.repeat(ap_s, P)) + rows["brier_score_mean"].append(np.repeat(brier_m, P)) + rows["brier_score_std"].append(np.repeat(brier_s, P)) + + rows["precision_at_K_mean"].append(prec_m.T.reshape(-1)) + rows["precision_at_K_std"].append(prec_s.T.reshape(-1)) + rows["recall_at_K_mean"].append(rec_m.T.reshape(-1)) + rows["recall_at_K_std"].append(rec_s.T.reshape(-1)) + + out = {k: np.concatenate(v, axis=0) for k, v in rows.items()} + df = pd.DataFrame(out) + return df.sort_values( + ["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True + ) + + def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: """Aggregate per-bin age evaluation results. @@ -129,34 +487,79 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: ) ) + # Kept for backward compatibility (e.g., if callers load a CSV and need to + # aggregate). Prefer `aggregate_metrics_columnar` during evaluation. group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"] - gb = df_by_bin.groupby(group_keys) + df = df_by_bin[df_by_bin["n_samples"] > 0].copy() + if len(df) == 0: + return pd.DataFrame( + columns=[ + "agg_type", + "horizon_tau", + "topk_percent", + "cause_id", + "n_mc", + "n_bins_used_mean", + "n_samples_total_mean", + "n_positives_total_mean", + "auc_mean", + "auc_std", + "auprc_mean", + "auprc_std", + "recall_at_K_mean", + "recall_at_K_std", + "precision_at_K_mean", + "precision_at_K_std", + "brier_score_mean", + "brier_score_std", + ] + ) - try: - df_mc_macro = gb.apply( - lambda g: _bin_aggregate(g, weighted=False), include_groups=False - ).reset_index() - except TypeError: # pandas<2.2 (no include_groups) - df_mc_macro = gb.apply(lambda g: _bin_aggregate( - g, weighted=False)).reset_index() + # Macro: mean over bins + df_mc_macro = ( + df.groupby(group_keys, as_index=False) + .agg( + n_bins_used=("age_bin_id", "nunique"), + n_samples_total=("n_samples", "sum"), + n_positives_total=("n_positives", "sum"), + auc=("auc", "mean"), + auprc=("auprc", "mean"), + recall_at_K=("recall_at_K", "mean"), + precision_at_K=("precision_at_K", "mean"), + brier_score=("brier_score", "mean"), + ) + ) df_mc_macro["agg_type"] = "macro" - try: - df_mc_weighted = gb.apply( - lambda g: _bin_aggregate(g, weighted=True), include_groups=False - ).reset_index() - except TypeError: # pandas<2.2 (no include_groups) - df_mc_weighted = gb.apply( - lambda g: _bin_aggregate(g, weighted=True)).reset_index() - df_mc_weighted["agg_type"] = "weighted" + # Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric + w = df["n_samples"].astype(float) + df_w = df.copy() + for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]: + m = df_w[col].astype(float) + ww = w.where(m.notna(), other=0.0) + df_w[f"__num_{col}"] = (m.fillna(0.0) * w) + df_w[f"__den_{col}"] = ww - df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True) + df_mc_w = df_w.groupby(group_keys, as_index=False).agg( + n_bins_used=("age_bin_id", "nunique"), + n_samples_total=("n_samples", "sum"), + n_positives_total=("n_positives", "sum"), + **{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]}, + **{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]}, + ) + for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]: + num = df_mc_w[f"__num_{col}"].astype(float) + den = df_mc_w[f"__den_{col}"].astype(float) + df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan")) + df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True) + df_mc_w["agg_type"] = "weighted" + + df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True) df_agg = ( df_mc_binagg.groupby( - ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False - ) + ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False) .agg( n_mc=("mc_idx", "nunique"), n_bins_used_mean=("n_bins_used", "mean"), @@ -173,10 +576,7 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: brier_score_mean=("brier_score", "mean"), brier_score_std=("brier_score", "std"), ) - .sort_values( - ["agg_type", "horizon_tau", "topk_percent", "cause_id"], - ignore_index=True, - ) + .sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True) ) return df_agg @@ -193,6 +593,7 @@ class EvalAgeConfig: n_mc: int = 5 seed: int = 0 cause_ids: Optional[Sequence[int]] = None + store_per_cause: bool = True @torch.inference_mode() @@ -247,12 +648,17 @@ def evaluate_time_dependent_age_bins( if cfg.cause_ids is None: cause_ids = None n_causes_eval = int(n_disease) + cause_id_vec = np.arange(n_causes_eval, dtype=int) else: cause_ids = torch.tensor( list(cfg.cause_ids), dtype=torch.long, device=device) n_causes_eval = int(cause_ids.numel()) + cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int) - rows_by_bin: List[Dict[str, float | int]] = [] + topk_percents_np = np.asarray(topk_percents, dtype=float) + + # Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends. + blocks: List[Dict[str, Any]] = [] for mc_idx in range(int(cfg.n_mc)): global_mc_idx = int(mc_offset) + int(mc_idx) @@ -354,29 +760,27 @@ def evaluate_time_dependent_age_bins( for h_idx, tau_y in enumerate(horizons_years): for bin_idx, (a_lo, a_hi) in enumerate(age_bins): if len(y_true_mc[h_idx][bin_idx]) == 0: - # No samples in this bin for this (mc, tau) - for cause_k in range(n_causes_eval): - cause_id = int(cause_k) if cause_ids is None else int( - cfg.cause_ids[cause_k]) - for k_percent in topk_percents: - rows_by_bin.append( - dict( - mc_idx=global_mc_idx, - age_bin_id=bin_idx, - age_bin_low=float(a_lo), - age_bin_high=float(a_hi), - horizon_tau=float(tau_y), - topk_percent=float(k_percent), - cause_id=cause_id, - n_samples=0, - n_positives=0, - auc=float("nan"), - auprc=float("nan"), - recall_at_K=float("nan"), - precision_at_K=float("nan"), - brier_score=float("nan"), - ) - ) + # No samples in this bin for this (mc, tau): store a single + # block with NaN metric vectors. + K = int(n_causes_eval) + P = int(topk_percents_np.size) + blocks.append( + dict( + mc_idx=global_mc_idx, + age_bin_id=bin_idx, + age_bin_low=float(a_lo), + age_bin_high=float(a_hi), + horizon_tau=float(tau_y), + n_samples=0, + cause_id=cause_id_vec, + n_positives=np.zeros((K,), dtype=int), + auc=np.full((K,), np.nan, dtype=float), + auprc=np.full((K,), np.nan, dtype=float), + brier_score=np.full((K,), np.nan, dtype=float), + precision_at_K=np.full((P, K), np.nan, dtype=float), + recall_at_K=np.full((P, K), np.nan, dtype=float), + ) + ) continue yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0) @@ -397,40 +801,52 @@ def evaluate_time_dependent_age_bins( compute_ici=False, ) - # Move just the metric vectors to CPU once per (mc, tau, bin) - # for DataFrame construction. - auc = metrics.auc_per_cause.detach().cpu().numpy() - auprc = metrics.ap_per_cause.detach().cpu().numpy() - brier = metrics.brier_per_cause.detach().cpu().numpy() - n_pos = metrics.n_pos_per_cause.detach().cpu().numpy() - prec_at_k = metrics.precision_at_k.detach().cpu().numpy() # (P,K) - rec_at_k = metrics.recall_at_k.detach().cpu().numpy() # (P,K) + # Collect a single columnar block (vectors, not per-row dicts). + blocks.append( + dict( + mc_idx=global_mc_idx, + age_bin_id=bin_idx, + age_bin_low=float(a_lo), + age_bin_high=float(a_hi), + horizon_tau=float(tau_y), + n_samples=int(n_samples), + cause_id=cause_id_vec, + n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int), + auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float), + auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float), + brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float), + precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float), + recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float), + ) + ) - for cause_k in range(n_causes_eval): - cause_id = int(cause_k) if cause_ids is None else int( - cfg.cause_ids[cause_k]) - for p_idx, k_percent in enumerate(topk_percents): - rows_by_bin.append( - dict( - mc_idx=global_mc_idx, - age_bin_id=bin_idx, - age_bin_low=float(a_lo), - age_bin_high=float(a_hi), - horizon_tau=float(tau_y), - topk_percent=float(k_percent), - cause_id=cause_id, - n_samples=n_samples, - n_positives=int(n_pos[cause_k]), - auc=float(auc[cause_k]), - auprc=float(auprc[cause_k]), - recall_at_K=float(rec_at_k[p_idx, cause_k]), - precision_at_K=float(prec_at_k[p_idx, cause_k]), - brier_score=float(brier[cause_k]), - ) - ) + # Aggregation is computed from columnar blocks (fast, no pandas apply). + df_agg = aggregate_metrics_columnar( + blocks, + topk_percents=topk_percents_np, + cause_id=cause_id_vec, + ) - df_by_bin = pd.DataFrame(rows_by_bin) - - df_agg = aggregate_age_bin_results(df_by_bin) + if bool(cfg.store_per_cause): + df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np) + else: + df_by_bin = pd.DataFrame( + columns=[ + "mc_idx", + "age_bin_id", + "age_bin_low", + "age_bin_high", + "horizon_tau", + "topk_percent", + "cause_id", + "n_samples", + "n_positives", + "auc", + "auprc", + "recall_at_K", + "precision_at_K", + "brier_score", + ] + ) return df_by_bin, df_agg