Add function to drop zero-positive rows and update CSV export logic in age-bin evaluation
@@ -177,6 +177,7 @@ def _worker_eval_mcs_on_gpu(
 
         df_all = pd.concat(frames, ignore_index=True) if len(
             frames) else pd.DataFrame()
+        df_all = _drop_zero_positives_rows(df_all, "n_positives")
         df_all.to_csv(out_path, index=False)
         queue.put({"ok": True, "out_path": out_path})
     except Exception as e:
@@ -211,6 +212,20 @@ def build_criterion_and_out_dims(loss_type: str, n_disease: int, bin_edges, lamb
     raise ValueError(f"Unsupported loss_type: {loss_type}")
 
 
+def _drop_zero_positives_rows(df: pd.DataFrame, positive_col: str) -> pd.DataFrame:
+    """Drop rows where the provided positives column is <= 0.
+
+    Intended to reduce CSV size by omitting (cause, horizon, bin) rows that have
+    no positives, which otherwise yield undefined/NaN metrics.
+    """
+    if df is None or len(df) == 0:
+        return df
+    if positive_col not in df.columns:
+        return df
+    pos = pd.to_numeric(df[positive_col], errors="coerce")
+    return df[pos > 0].copy()
+
+
 def build_model(model_type: str, *, dataset: HealthDataset, cfg: dict):
     if model_type == "delphi_fork":
         return DelphiFork(
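
A minimal sketch of the filtering this hunk introduces, on a toy metrics frame (values made up; the filter re-applies the helper's core expression):

import pandas as pd

df = pd.DataFrame({
    "cause_id": [0, 1, 2],
    "n_positives": [5, 0, 3],
    "auc": [0.71, float("nan"), 0.80],
})

pos = pd.to_numeric(df["n_positives"], errors="coerce")
print(df[pos > 0])  # rows for cause_id 0 and 2 survive; the zero-positive row is dropped
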
@@ -404,8 +419,10 @@ def main() -> None:
             device=device,
         )
 
-        df_by_bin.to_csv(out_bin, index=False)
-        df_agg.to_csv(out_agg, index=False)
+        df_by_bin_csv = _drop_zero_positives_rows(df_by_bin, "n_positives")
+        df_agg_csv = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
+        df_by_bin_csv.to_csv(out_bin, index=False)
+        df_agg_csv.to_csv(out_agg, index=False)
         print(f"Wrote: {out_bin}")
         print(f"Wrote: {out_agg}")
         return
@@ -464,8 +481,13 @@ def main() -> None:
     frames = [pd.read_csv(p) for p in tmp_paths if os.path.exists(p)]
     df_by_bin = pd.concat(frames, ignore_index=True) if len(
        frames) else pd.DataFrame()
 
+    # Ensure we don't keep zero-positive rows even if a temp file was produced
+    # by an older version of the worker.
+    df_by_bin = _drop_zero_positives_rows(df_by_bin, "n_positives")
     df_agg = aggregate_age_bin_results(df_by_bin)
+
+    df_agg = _drop_zero_positives_rows(df_agg, "n_positives_total_mean")
     df_by_bin.to_csv(out_bin, index=False)
     df_agg.to_csv(out_agg, index=False)
 
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import math
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
 import pandas as pd
@@ -23,6 +23,364 @@ from utils import (
 from torch_metrics import compute_binary_metrics_torch
 
 
+def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray:
+    with np.errstate(invalid="ignore"):
+        return np.nanmean(x, axis=axis)
+
+
+def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray:
+    """NaN-aware sample std (ddof=1), matching pandas std() semantics."""
+    x = np.asarray(x, dtype=float)
+    mask = np.isfinite(x)
+    cnt = mask.sum(axis=axis)
+    # mean over finite entries
+    x0 = np.where(mask, x, 0.0)
+    mean = x0.sum(axis=axis) / np.maximum(cnt, 1)
+    # sum of squared deviations over finite entries
+    dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0)
+    ss = dev2.sum(axis=axis)
+    denom = cnt - 1
+    out = np.sqrt(ss / np.maximum(denom, 1))
+    out = np.where(denom > 0, out, np.nan)
+    return out
+
+
+def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray:
+    """NaN-aware weighted mean.
+
+    Only bins with finite x contribute to both numerator and denominator.
+    If denom==0 -> NaN.
+    """
+    x = np.asarray(x, dtype=float)
+    w = np.asarray(w, dtype=float)
+
+    if axis != 0:
+        raise ValueError("_weighted_mean_np currently supports axis=0 only")
+
+    # Broadcast weights along trailing dims of x.
+    while w.ndim < x.ndim:
+        w = w[..., None]
+    w = np.broadcast_to(w, x.shape)
+
+    mask = np.isfinite(x)
+    num = np.where(mask, x * w, 0.0).sum(axis=0)
+    denom = np.where(mask, w, 0.0).sum(axis=0)
+    return np.where(denom > 0.0, num / denom, np.nan)
+
+
+def _blocks_to_df_by_bin(
+    blocks: List[Dict[str, Any]],
+    *,
+    topk_percents: np.ndarray,
+) -> pd.DataFrame:
+    """Convert per-block column vectors into the long-format per-bin DataFrame.
+
+    This does a single vectorized reshape per block (cause-major ordering), and
+    concatenates columns once at the end.
+    """
+    if len(blocks) == 0:
+        return pd.DataFrame(
+            columns=[
+                "mc_idx",
+                "age_bin_id",
+                "age_bin_low",
+                "age_bin_high",
+                "horizon_tau",
+                "topk_percent",
+                "cause_id",
+                "n_samples",
+                "n_positives",
+                "auc",
+                "auprc",
+                "recall_at_K",
+                "precision_at_K",
+                "brier_score",
+            ]
+        )
+
+    P = int(topk_percents.size)
+
+    cols: Dict[str, List[np.ndarray]] = {
+        "mc_idx": [],
+        "age_bin_id": [],
+        "age_bin_low": [],
+        "age_bin_high": [],
+        "horizon_tau": [],
+        "topk_percent": [],
+        "cause_id": [],
+        "n_samples": [],
+        "n_positives": [],
+        "auc": [],
+        "auprc": [],
+        "recall_at_K": [],
+        "precision_at_K": [],
+        "brier_score": [],
+    }
+
+    for blk in blocks:
+        cause_id = np.asarray(blk["cause_id"], dtype=int)
+        K = int(cause_id.size)
+        n_rows = K * P
+
+        cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int))
+        cols["age_bin_id"].append(
+            np.full(n_rows, int(blk["age_bin_id"]), dtype=int))
+        cols["age_bin_low"].append(
+            np.full(n_rows, float(blk["age_bin_low"]), dtype=float))
+        cols["age_bin_high"].append(
+            np.full(n_rows, float(blk["age_bin_high"]), dtype=float))
+        cols["horizon_tau"].append(
+            np.full(n_rows, float(blk["horizon_tau"]), dtype=float))
+
+        cols["cause_id"].append(np.repeat(cause_id, P))
+        cols["topk_percent"].append(np.tile(topk_percents.astype(float), K))
+        cols["n_samples"].append(
+            np.full(n_rows, int(blk["n_samples"]), dtype=int))
+
+        n_pos = np.asarray(blk["n_positives"], dtype=int)
+        cols["n_positives"].append(np.repeat(n_pos, P))
+
+        auc = np.asarray(blk["auc"], dtype=float)
+        auprc = np.asarray(blk["auprc"], dtype=float)
+        brier = np.asarray(blk["brier_score"], dtype=float)
+        cols["auc"].append(np.repeat(auc, P))
+        cols["auprc"].append(np.repeat(auprc, P))
+        cols["brier_score"].append(np.repeat(brier, P))
+
+        # precision/recall are stored as (P,K); we want cause-major rows, i.e.
+        # (K,P) then flatten.
+        prec = np.asarray(blk["precision_at_K"], dtype=float)
+        rec = np.asarray(blk["recall_at_K"], dtype=float)
+        if prec.shape != (P, K) or rec.shape != (P, K):
+            raise ValueError(
+                f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}"
+            )
+        cols["precision_at_K"].append(prec.T.reshape(-1))
+        cols["recall_at_K"].append(rec.T.reshape(-1))
+
+    out = {k: np.concatenate(v, axis=0) for k, v in cols.items()}
+    return pd.DataFrame(out)
+
+
+def aggregate_metrics_columnar(
+    blocks: List[Dict[str, Any]],
+    *,
+    topk_percents: np.ndarray,
+    cause_id: np.ndarray,
+) -> pd.DataFrame:
+    """Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std).
+
+    This is a vectorized, columnar replacement for the old pandas groupby/apply.
+    Semantics match the previous implementation:
+    - bins with n_samples==0 are excluded from bin-aggregation
+    - macro: unweighted mean over bins (NaN-aware)
+    - weighted: weighted mean over bins using weights=n_samples (NaN-aware)
+    - across MC: mean/std (ddof=1), NaN-aware
+    """
+    if len(blocks) == 0:
+        return pd.DataFrame(
+            columns=[
+                "agg_type",
+                "horizon_tau",
+                "topk_percent",
+                "cause_id",
+                "n_mc",
+                "n_bins_used_mean",
+                "n_samples_total_mean",
+                "n_positives_total_mean",
+                "auc_mean",
+                "auc_std",
+                "auprc_mean",
+                "auprc_std",
+                "recall_at_K_mean",
+                "recall_at_K_std",
+                "precision_at_K_mean",
+                "precision_at_K_std",
+                "brier_score_mean",
+                "brier_score_std",
+            ]
+        )
+
+    P = int(topk_percents.size)
+    cause_id = np.asarray(cause_id, dtype=int)
+    K = int(cause_id.size)
+
+    # Group blocks by (mc_idx, horizon_tau)
+    keys: List[Tuple[int, float]] = []
+    grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {}
+    for blk in blocks:
+        key = (int(blk["mc_idx"]), float(blk["horizon_tau"]))
+        if key not in grouped:
+            grouped[key] = []
+            keys.append(key)
+        grouped[key].append(blk)
+
+    mc_vals = sorted({k[0] for k in keys})
+    tau_vals = sorted({k[1] for k in keys})
+    M = len(mc_vals)
+    T = len(tau_vals)
+
+    mc_index = {mc: i for i, mc in enumerate(mc_vals)}
+    tau_index = {tau: i for i, tau in enumerate(tau_vals)}
+
+    # Per (agg_type, mc, tau): store arrays
+    # metrics: (M,T,K) and (M,T,P,K)
+    auc_macro = np.full((M, T, K), np.nan, dtype=float)
+    auc_weighted = np.full((M, T, K), np.nan, dtype=float)
+    ap_macro = np.full((M, T, K), np.nan, dtype=float)
+    ap_weighted = np.full((M, T, K), np.nan, dtype=float)
+    brier_macro = np.full((M, T, K), np.nan, dtype=float)
+    brier_weighted = np.full((M, T, K), np.nan, dtype=float)
+
+    prec_macro = np.full((M, T, P, K), np.nan, dtype=float)
+    prec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
+    rec_macro = np.full((M, T, P, K), np.nan, dtype=float)
+    rec_weighted = np.full((M, T, P, K), np.nan, dtype=float)
+
+    n_bins_used = np.zeros((M, T), dtype=float)
+    n_samples_total = np.zeros((M, T), dtype=float)
+    n_pos_total = np.zeros((M, T, K), dtype=float)
+
+    for (mc, tau), blks in grouped.items():
+        mi = mc_index[mc]
+        ti = tau_index[tau]
+
+        # keep only bins with n_samples>0
+        blks_nz = [b for b in blks if int(b["n_samples"]) > 0]
+        if len(blks_nz) == 0:
+            n_bins_used[mi, ti] = 0.0
+            n_samples_total[mi, ti] = 0.0
+            n_pos_total[mi, ti, :] = 0.0
+            continue
+
+        w = np.asarray([int(b["n_samples"])
+                        for b in blks_nz], dtype=float)  # (B,)
+        n_bins_used[mi, ti] = float(len(w))
+        n_samples_total[mi, ti] = float(w.sum())
+
+        npos = np.stack([np.asarray(b["n_positives"], dtype=float)
+                         for b in blks_nz], axis=0)  # (B,K)
+        n_pos_total[mi, ti, :] = npos.sum(axis=0)
+
+        auc_b = np.stack([np.asarray(b["auc"], dtype=float)
+                          for b in blks_nz], axis=0)  # (B,K)
+        ap_b = np.stack([np.asarray(b["auprc"], dtype=float)
+                         for b in blks_nz], axis=0)
+        brier_b = np.stack([np.asarray(b["brier_score"], dtype=float)
+                            for b in blks_nz], axis=0)
+
+        auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0)
+        ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0)
+        brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0)
+
+        auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0)
+        ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0)
+        brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0)
+
+        prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float)
+                           for b in blks_nz], axis=0)  # (B,P,K)
+        rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float)
+                          for b in blks_nz], axis=0)
+
+        # macro mean over bins
+        prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0)
+        rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0)
+
+        # weighted mean over bins (weights along bin axis)
+        w3 = w.reshape(-1, 1, 1)
+        prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0)
+        rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0)
+
+    # Across-MC aggregation (mean/std), then emit long-format df keyed by
+    # (agg_type, horizon_tau, topk_percent, cause_id)
+    rows: Dict[str, List[np.ndarray]] = {
+        "agg_type": [],
+        "horizon_tau": [],
+        "topk_percent": [],
+        "cause_id": [],
+        "n_mc": [],
+        "n_bins_used_mean": [],
+        "n_samples_total_mean": [],
+        "n_positives_total_mean": [],
+        "auc_mean": [],
+        "auc_std": [],
+        "auprc_mean": [],
+        "auprc_std": [],
+        "recall_at_K_mean": [],
+        "recall_at_K_std": [],
+        "precision_at_K_mean": [],
+        "precision_at_K_std": [],
+        "brier_score_mean": [],
+        "brier_score_std": [],
+    }
+
+    cause_long = np.repeat(cause_id, P)
+    topk_long = np.tile(topk_percents.astype(float), K)
+    n_mc_val = float(M)
+
+    for ti, tau in enumerate(tau_vals):
+        # scalar totals (repeat across causes/topk)
+        n_bins_mean = float(
+            np.mean(n_bins_used[:, ti])) if M > 0 else float("nan")
+        n_samp_mean = float(
+            np.mean(n_samples_total[:, ti])) if M > 0 else float("nan")
+        n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0)  # (K,)
+
+        for agg_type in ("macro", "weighted"):
+            if agg_type == "macro":
+                auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0)
+                auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0)
+                ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0)
+                ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0)
+                brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0)
+                brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0)
+                prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0)  # (P,K)
+                prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0)
+                rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0)
+                rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0)
+            else:
+                auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0)
+                auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0)
+                ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0)
+                ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0)
+                brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0)
+                brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0)
+                prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0)
+                prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0)
+                rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0)
+                rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0)
+
+            n_rows = K * P
+            rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object))
+            rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float))
+            rows["topk_percent"].append(topk_long)
+            rows["cause_id"].append(cause_long)
+            rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float))
+            rows["n_bins_used_mean"].append(
+                np.full(n_rows, n_bins_mean, dtype=float))
+            rows["n_samples_total_mean"].append(
+                np.full(n_rows, n_samp_mean, dtype=float))
+            rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P))
+
+            rows["auc_mean"].append(np.repeat(auc_m, P))
+            rows["auc_std"].append(np.repeat(auc_s, P))
+            rows["auprc_mean"].append(np.repeat(ap_m, P))
+            rows["auprc_std"].append(np.repeat(ap_s, P))
+            rows["brier_score_mean"].append(np.repeat(brier_m, P))
+            rows["brier_score_std"].append(np.repeat(brier_s, P))
+
+            rows["precision_at_K_mean"].append(prec_m.T.reshape(-1))
+            rows["precision_at_K_std"].append(prec_s.T.reshape(-1))
+            rows["recall_at_K_mean"].append(rec_m.T.reshape(-1))
+            rows["recall_at_K_std"].append(rec_s.T.reshape(-1))
+
+    out = {k: np.concatenate(v, axis=0) for k, v in rows.items()}
+    df = pd.DataFrame(out)
+    return df.sort_values(
+        ["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True
+    )
+
+
 def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
     """Aggregate per-bin age evaluation results.
 
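
A toy check of the docstring claim that the NaN-aware, ddof=1 standard deviation matches pandas Series.std(); the computation below condenses _nanstd_np_ddof1 for a 1-D input, and the sample numbers are made up:

import numpy as np
import pandas as pd

x = np.array([0.7, np.nan, 0.9, 0.8])

mask = np.isfinite(x)
cnt = int(mask.sum())
mean = np.where(mask, x, 0.0).sum() / max(cnt, 1)         # mean over finite entries
ss = np.where(mask, (x - mean) ** 2, 0.0).sum()           # sum of squared deviations
std_np = np.sqrt(ss / (cnt - 1)) if cnt > 1 else float("nan")

print(std_np, pd.Series(x).std())  # both ~0.1: the NaN is skipped, ddof=1
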
@@ -129,34 +487,79 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
         )
     )
 
+    # Kept for backward compatibility (e.g., if callers load a CSV and need to
+    # aggregate). Prefer `aggregate_metrics_columnar` during evaluation.
     group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
 
-    gb = df_by_bin.groupby(group_keys)
+    df = df_by_bin[df_by_bin["n_samples"] > 0].copy()
+    if len(df) == 0:
+        return pd.DataFrame(
+            columns=[
+                "agg_type",
+                "horizon_tau",
+                "topk_percent",
+                "cause_id",
+                "n_mc",
+                "n_bins_used_mean",
+                "n_samples_total_mean",
+                "n_positives_total_mean",
+                "auc_mean",
+                "auc_std",
+                "auprc_mean",
+                "auprc_std",
+                "recall_at_K_mean",
+                "recall_at_K_std",
+                "precision_at_K_mean",
+                "precision_at_K_std",
+                "brier_score_mean",
+                "brier_score_std",
+            ]
+        )
 
-    try:
-        df_mc_macro = gb.apply(
-            lambda g: _bin_aggregate(g, weighted=False), include_groups=False
-        ).reset_index()
-    except TypeError:  # pandas<2.2 (no include_groups)
-        df_mc_macro = gb.apply(lambda g: _bin_aggregate(
-            g, weighted=False)).reset_index()
+    # Macro: mean over bins
+    df_mc_macro = (
+        df.groupby(group_keys, as_index=False)
+        .agg(
+            n_bins_used=("age_bin_id", "nunique"),
+            n_samples_total=("n_samples", "sum"),
+            n_positives_total=("n_positives", "sum"),
+            auc=("auc", "mean"),
+            auprc=("auprc", "mean"),
+            recall_at_K=("recall_at_K", "mean"),
+            precision_at_K=("precision_at_K", "mean"),
+            brier_score=("brier_score", "mean"),
+        )
+    )
     df_mc_macro["agg_type"] = "macro"
 
-    try:
-        df_mc_weighted = gb.apply(
-            lambda g: _bin_aggregate(g, weighted=True), include_groups=False
-        ).reset_index()
-    except TypeError:  # pandas<2.2 (no include_groups)
-        df_mc_weighted = gb.apply(
-            lambda g: _bin_aggregate(g, weighted=True)).reset_index()
-    df_mc_weighted["agg_type"] = "weighted"
+    # Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric
+    w = df["n_samples"].astype(float)
+    df_w = df.copy()
+    for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
+        m = df_w[col].astype(float)
+        ww = w.where(m.notna(), other=0.0)
+        df_w[f"__num_{col}"] = (m.fillna(0.0) * w)
+        df_w[f"__den_{col}"] = ww
 
-    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)
+    df_mc_w = df_w.groupby(group_keys, as_index=False).agg(
+        n_bins_used=("age_bin_id", "nunique"),
+        n_samples_total=("n_samples", "sum"),
+        n_positives_total=("n_positives", "sum"),
+        **{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
+        **{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]},
+    )
+    for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]:
+        num = df_mc_w[f"__num_{col}"].astype(float)
+        den = df_mc_w[f"__den_{col}"].astype(float)
+        df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan"))
+        df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True)
+    df_mc_w["agg_type"] = "weighted"
+
+    df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True)
 
     df_agg = (
         df_mc_binagg.groupby(
-            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
-        )
+            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False)
         .agg(
             n_mc=("mc_idx", "nunique"),
             n_bins_used_mean=("n_bins_used", "mean"),
@@ -173,10 +576,7 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
             brier_score_mean=("brier_score", "mean"),
             brier_score_std=("brier_score", "std"),
         )
-        .sort_values(
-            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
-            ignore_index=True,
-        )
+        .sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True)
     )
     return df_agg
 
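
The rewritten weighted branch of aggregate_age_bin_results avoids groupby().apply() by summing per-metric numerators (metric * weight) and denominators (weights of bins where the metric is finite) and dividing after the groupby. A small sketch of that trick on made-up numbers (column names follow the diff):

import numpy as np
import pandas as pd

df = pd.DataFrame({
    "cause_id": [0, 0, 0],
    "n_samples": [10, 30, 60],
    "auc": [0.70, np.nan, 0.80],
})

w = df["n_samples"].astype(float)
m = df["auc"].astype(float)
num = (m.fillna(0.0) * w).groupby(df["cause_id"]).sum()            # sum of metric * weight
den = w.where(m.notna(), other=0.0).groupby(df["cause_id"]).sum()  # weights of finite bins only
weighted_auc = (num / den).where(den > 0.0, other=float("nan"))
print(weighted_auc)  # (0.7*10 + 0.8*60) / 70 = ~0.786; the NaN bin drops out
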
@@ -193,6 +593,7 @@ class EvalAgeConfig:
     n_mc: int = 5
     seed: int = 0
     cause_ids: Optional[Sequence[int]] = None
+    store_per_cause: bool = True
 
 
 @torch.inference_mode()
@@ -247,12 +648,17 @@ def evaluate_time_dependent_age_bins(
     if cfg.cause_ids is None:
         cause_ids = None
         n_causes_eval = int(n_disease)
+        cause_id_vec = np.arange(n_causes_eval, dtype=int)
     else:
         cause_ids = torch.tensor(
             list(cfg.cause_ids), dtype=torch.long, device=device)
         n_causes_eval = int(cause_ids.numel())
+        cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int)
 
-    rows_by_bin: List[Dict[str, float | int]] = []
+    topk_percents_np = np.asarray(topk_percents, dtype=float)
 
+    # Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends.
+    blocks: List[Dict[str, Any]] = []
+
     for mc_idx in range(int(cfg.n_mc)):
         global_mc_idx = int(mc_offset) + int(mc_idx)
@@ -354,27 +760,25 @@ def evaluate_time_dependent_age_bins(
         for h_idx, tau_y in enumerate(horizons_years):
             for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                 if len(y_true_mc[h_idx][bin_idx]) == 0:
-                    # No samples in this bin for this (mc, tau)
-                    for cause_k in range(n_causes_eval):
-                        cause_id = int(cause_k) if cause_ids is None else int(
-                            cfg.cause_ids[cause_k])
-                        for k_percent in topk_percents:
-                            rows_by_bin.append(
+                    # No samples in this bin for this (mc, tau): store a single
+                    # block with NaN metric vectors.
+                    K = int(n_causes_eval)
+                    P = int(topk_percents_np.size)
+                    blocks.append(
                         dict(
                             mc_idx=global_mc_idx,
                             age_bin_id=bin_idx,
                             age_bin_low=float(a_lo),
                             age_bin_high=float(a_hi),
                             horizon_tau=float(tau_y),
-                            topk_percent=float(k_percent),
-                            cause_id=cause_id,
                             n_samples=0,
-                            n_positives=0,
-                            auc=float("nan"),
-                            auprc=float("nan"),
-                            recall_at_K=float("nan"),
-                            precision_at_K=float("nan"),
-                            brier_score=float("nan"),
+                            cause_id=cause_id_vec,
+                            n_positives=np.zeros((K,), dtype=int),
+                            auc=np.full((K,), np.nan, dtype=float),
+                            auprc=np.full((K,), np.nan, dtype=float),
+                            brier_score=np.full((K,), np.nan, dtype=float),
+                            precision_at_K=np.full((P, K), np.nan, dtype=float),
+                            recall_at_K=np.full((P, K), np.nan, dtype=float),
                         )
                     )
                     continue
@@ -397,40 +801,52 @@ def evaluate_time_dependent_age_bins(
                     compute_ici=False,
                 )
 
-                # Move just the metric vectors to CPU once per (mc, tau, bin)
-                # for DataFrame construction.
-                auc = metrics.auc_per_cause.detach().cpu().numpy()
-                auprc = metrics.ap_per_cause.detach().cpu().numpy()
-                brier = metrics.brier_per_cause.detach().cpu().numpy()
-                n_pos = metrics.n_pos_per_cause.detach().cpu().numpy()
-                prec_at_k = metrics.precision_at_k.detach().cpu().numpy()  # (P,K)
-                rec_at_k = metrics.recall_at_k.detach().cpu().numpy()  # (P,K)
-
-                for cause_k in range(n_causes_eval):
-                    cause_id = int(cause_k) if cause_ids is None else int(
-                        cfg.cause_ids[cause_k])
-                    for p_idx, k_percent in enumerate(topk_percents):
-                        rows_by_bin.append(
+                # Collect a single columnar block (vectors, not per-row dicts).
+                blocks.append(
                     dict(
                         mc_idx=global_mc_idx,
                         age_bin_id=bin_idx,
                         age_bin_low=float(a_lo),
                         age_bin_high=float(a_hi),
                         horizon_tau=float(tau_y),
-                        topk_percent=float(k_percent),
-                        cause_id=cause_id,
-                        n_samples=n_samples,
-                        n_positives=int(n_pos[cause_k]),
-                        auc=float(auc[cause_k]),
-                        auprc=float(auprc[cause_k]),
-                        recall_at_K=float(rec_at_k[p_idx, cause_k]),
-                        precision_at_K=float(prec_at_k[p_idx, cause_k]),
-                        brier_score=float(brier[cause_k]),
+                        n_samples=int(n_samples),
+                        cause_id=cause_id_vec,
+                        n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int),
+                        auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float),
+                        auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float),
+                        brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float),
+                        precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float),
+                        recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float),
                     )
                 )
 
-    df_by_bin = pd.DataFrame(rows_by_bin)
+    # Aggregation is computed from columnar blocks (fast, no pandas apply).
+    df_agg = aggregate_metrics_columnar(
+        blocks,
+        topk_percents=topk_percents_np,
+        cause_id=cause_id_vec,
+    )
 
-    df_agg = aggregate_age_bin_results(df_by_bin)
+    if bool(cfg.store_per_cause):
+        df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np)
+    else:
+        df_by_bin = pd.DataFrame(
+            columns=[
+                "mc_idx",
+                "age_bin_id",
+                "age_bin_low",
+                "age_bin_high",
+                "horizon_tau",
+                "topk_percent",
+                "cause_id",
+                "n_samples",
+                "n_positives",
+                "auc",
+                "auprc",
+                "recall_at_K",
+                "precision_at_K",
+                "brier_score",
+            ]
+        )
 
     return df_by_bin, df_agg
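
A toy illustration of the cause-major row ordering that _blocks_to_df_by_bin relies on: cause ids are repeated with np.repeat, top-k percents tiled with np.tile, and (P, K) metric matrices are transposed before flattening so the three long vectors stay aligned. The arrays below are made up:

import numpy as np

cause_id = np.array([7, 9])                      # K = 2
topk = np.array([1.0, 5.0, 10.0])                # P = 3
prec = np.arange(6, dtype=float).reshape(3, 2)   # (P, K) precision matrix

rows_cause = np.repeat(cause_id, topk.size)      # [7 7 7 9 9 9]
rows_topk = np.tile(topk, cause_id.size)         # [1. 5. 10. 1. 5. 10.]
rows_prec = prec.T.reshape(-1)                   # [0. 2. 4. 1. 3. 5.]

for c, t, p in zip(rows_cause, rows_topk, rows_prec):
    print(c, t, p)  # each row pairs cause c with top-k t and prec[topk_index, cause_index]
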