from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

try:
    from tqdm import tqdm
except Exception:  # pragma: no cover
    def tqdm(x, **kwargs):
        return x

from utils import (
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    sample_context_in_fixed_age_bin,
)

from torch_metrics import compute_binary_metrics_torch


def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-bin age evaluation results.

    Produces both:
    - macro: unweighted mean over bins with n_samples > 0
    - weighted: mean over bins weighted by n_samples

    Then aggregates across MC repetitions (mean/std).

    Requires df_by_bin to include the columns:
        mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id,
        n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score

    Returns:
        DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id).
    """
    if df_by_bin is None or len(df_by_bin) == 0:
        return pd.DataFrame(
            columns=[
                "agg_type",
                "horizon_tau",
                "topk_percent",
                "cause_id",
                "n_mc",
                "n_bins_used_mean",
                "n_samples_total_mean",
                "n_positives_total_mean",
                "auc_mean",
                "auc_std",
                "auprc_mean",
                "auprc_std",
                "recall_at_K_mean",
                "recall_at_K_std",
                "precision_at_K_mean",
                "precision_at_K_std",
                "brier_score_mean",
                "brier_score_std",
            ]
        )

    def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
        g = group[group["n_samples"] > 0]
        if len(g) == 0:
            return pd.Series(
                dict(
                    n_bins_used=0,
                    n_samples_total=0,
                    n_positives_total=0,
                    auc=float("nan"),
                    auprc=float("nan"),
                    recall_at_K=float("nan"),
                    precision_at_K=float("nan"),
                    brier_score=float("nan"),
                )
            )

        n_bins_used = int(g["age_bin_id"].nunique())
        n_samples_total = int(g["n_samples"].sum())
        n_positives_total = int(g["n_positives"].sum())

        if not weighted:
            return pd.Series(
                dict(
                    n_bins_used=n_bins_used,
                    n_samples_total=n_samples_total,
                    n_positives_total=n_positives_total,
                    auc=float(g["auc"].mean()),
                    auprc=float(g["auprc"].mean()),
                    recall_at_K=float(g["recall_at_K"].mean()),
                    precision_at_K=float(g["precision_at_K"].mean()),
                    brier_score=float(g["brier_score"].mean()),
                )
            )

        w = g["n_samples"].to_numpy(dtype=float)
        w_sum = float(w.sum())
        if w_sum <= 0.0:
            return pd.Series(
                dict(
                    n_bins_used=n_bins_used,
                    n_samples_total=n_samples_total,
                    n_positives_total=n_positives_total,
                    auc=float("nan"),
                    auprc=float("nan"),
                    recall_at_K=float("nan"),
                    precision_at_K=float("nan"),
                    brier_score=float("nan"),
                )
            )

        def _wavg(col: str) -> float:
            return float(np.average(g[col].to_numpy(dtype=float), weights=w))

        return pd.Series(
            dict(
                n_bins_used=n_bins_used,
                n_samples_total=n_samples_total,
                n_positives_total=n_positives_total,
                auc=_wavg("auc"),
                auprc=_wavg("auprc"),
                recall_at_K=_wavg("recall_at_K"),
                precision_at_K=_wavg("precision_at_K"),
                brier_score=_wavg("brier_score"),
            )
        )

    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]

    df_mc_macro = (
        df_by_bin.groupby(group_keys)
        .apply(lambda g: _bin_aggregate(g, weighted=False))
        .reset_index()
    )
    df_mc_macro["agg_type"] = "macro"

    df_mc_weighted = (
        df_by_bin.groupby(group_keys)
        .apply(lambda g: _bin_aggregate(g, weighted=True))
        .reset_index()
    )
    df_mc_weighted["agg_type"] = "weighted"

    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)

    df_agg = (
        df_mc_binagg.groupby(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
        )
        .agg(
            n_mc=("mc_idx", "nunique"),
            n_bins_used_mean=("n_bins_used", "mean"),
            n_samples_total_mean=("n_samples_total", "mean"),
            n_positives_total_mean=("n_positives_total", "mean"),
            auc_mean=("auc", "mean"),
            auc_std=("auc", "std"),
            auprc_mean=("auprc", "mean"),
            auprc_std=("auprc", "std"),
            recall_at_K_mean=("recall_at_K", "mean"),
            recall_at_K_std=("recall_at_K", "std"),
            precision_at_K_mean=("precision_at_K", "mean"),
            precision_at_K_std=("precision_at_K", "std"),
            brier_score_mean=("brier_score", "mean"),
            brier_score_std=("brier_score", "std"),
        )
        .sort_values(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
            ignore_index=True,
        )
    )

    return df_agg


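# Hedged usage sketch (illustrative only, not part of the evaluation pipeline):
# with two bins in one MC repetition, the macro AUC is (0.6 + 0.8) / 2 = 0.70,
# while the n_samples-weighted AUC is (10 * 0.6 + 30 * 0.8) / 40 = 0.75. The
# toy frame below uses exactly the column schema documented above.
def _example_aggregate_age_bin_results() -> pd.DataFrame:
    toy = pd.DataFrame(
        [
            dict(mc_idx=0, horizon_tau=1.0, topk_percent=5.0, cause_id=0,
                 age_bin_id=0, n_samples=10, n_positives=2, auc=0.6, auprc=0.3,
                 recall_at_K=0.5, precision_at_K=0.2, brier_score=0.10),
            dict(mc_idx=0, horizon_tau=1.0, topk_percent=5.0, cause_id=0,
                 age_bin_id=1, n_samples=30, n_positives=9, auc=0.8, auprc=0.5,
                 recall_at_K=0.7, precision_at_K=0.4, brier_score=0.08),
        ]
    )
    return aggregate_age_bin_results(toy)

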
# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`.
# NumPy/Pandas are only used for final CSV formatting/aggregation.
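# As used below, the object returned by `compute_binary_metrics_torch` is
# assumed to expose per-cause tensors `auc_per_cause`, `ap_per_cause`,
# `brier_per_cause`, `n_pos_per_cause`, plus `precision_at_k` and
# `recall_at_k` shaped (len(k_percents), n_causes).

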
@dataclass
class EvalAgeConfig:
    horizons_years: Sequence[float]
    age_bins: Sequence[Tuple[float, float]]
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    n_mc: int = 5
    seed: int = 0
    cause_ids: Optional[Sequence[int]] = None


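# Hedged configuration sketch: the bin edges, horizons, and top-K percents
# below are illustrative, not values prescribed by this module.
_EXAMPLE_EVAL_AGE_CFG = EvalAgeConfig(
    horizons_years=(1.0, 5.0, 10.0),
    age_bins=tuple((float(lo), float(lo) + 5.0) for lo in range(40, 80, 5)),
    topk_percents=(1.0, 5.0, 10.0),
    n_mc=3,
    seed=42,
    cause_ids=None,  # None -> evaluate all n_disease causes
)

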
@torch.inference_mode()
def evaluate_time_dependent_age_bins(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalAgeConfig,
    device: str | torch.device,
    mc_offset: int = 0,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Delphi-2M-style age-bin evaluation with strict horizon alignment.

    Semantics (strict): for each (MC repetition, horizon tau, age bin) we independently
    - build the eligible token set within that bin,
    - enforce follow-up coverage: t_ctx + tau <= t_end,
    - randomly sample exactly one token per individual within the bin (de-dup),
    - recompute context representations and predictions for that (tau, bin).

    Returns:
        df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id)
        df_agg: metrics aggregated across age bins and MC repetitions, with agg_type in {macro, weighted}
    """
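    # Worked illustration of the coverage rule (ages and tau in years, assumed
    # convention of `sample_context_in_fixed_age_bin`): with tau = 5.0, a
    # context token at age 62.0 is eligible only if follow-up extends to at
    # least age 67.0; otherwise a negative label could merely reflect
    # censoring rather than absence of the event within the horizon.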
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")

    age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
    if len(age_bins) == 0:
        raise ValueError("cfg.age_bins must be non-empty")
    for (a, b) in age_bins:
        if not (b > a):
            raise ValueError(
                f"age_bins must be (low, high) with high > low; got {(a, b)}")

    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0, 100]; got {topk_percents}")

    if int(cfg.n_mc) <= 0:
        raise ValueError("cfg.n_mc must be >= 1")

    if cfg.cause_ids is None:
        cause_ids = None
        n_causes_eval = int(n_disease)
    else:
        cause_ids = torch.tensor(
            list(cfg.cause_ids), dtype=torch.long, device=device)
        n_causes_eval = int(cause_ids.numel())

    rows_by_bin: List[Dict[str, float | int]] = []

    for mc_idx in range(int(cfg.n_mc)):
        global_mc_idx = int(mc_offset) + int(mc_idx)

        # Storage for this MC only: (tau, bin) -> list of GPU tensors.
        # This keeps computations GPU-first while preventing a factor-n_mc
        # blow-up in GPU memory.
        y_true_mc: List[List[List[torch.Tensor]]] = [
            [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
        ]
        y_pred_mc: List[List[List[torch.Tensor]]] = [
            [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
        ]

        # tqdm over batches; include the MC index in the description.
        for batch_idx, batch in enumerate(
            tqdm(dataloader,
                 desc=f"Evaluating (MC {mc_idx + 1}/{cfg.n_mc})", unit="batch")
        ):
            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
            event_seq = event_seq.to(device)
            time_seq = time_seq.to(device)
            cont_feats = cont_feats.to(device)
            cate_feats = cate_feats.to(device)
            sexes = sexes.to(device)

            B = int(event_seq.size(0))
            b = torch.arange(B, device=device)

            # Hoist the backbone forward pass: inputs are identical across
            # (tau, age_bin) within this batch, so this is safe and
            # numerically identical.
            h = model(event_seq, time_seq, sexes,
                      cont_feats, cate_feats)  # (B, L, D)

            for tau_idx, tau_y in enumerate(horizons_years):
                tau_tensor = torch.tensor(float(tau_y), device=device)
                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                    # Diversify the RNG stream across MC/tau/bin/batch to
                    # reduce correlation between repetitions.
                    seed = (
                        int(cfg.seed)
                        + (100_000 * int(global_mc_idx))
                        + (1_000 * int(tau_idx))
                        + (10 * int(bin_idx))
                        + int(batch_idx)
                    )

                    keep, t_ctx = sample_context_in_fixed_age_bin(
                        event_seq=event_seq,
                        time_seq=time_seq,
                        tau_years=float(tau_y),
                        age_bin=(float(a_lo), float(a_hi)),
                        seed=seed,
                    )
                    if not keep.any():
                        continue

                    # Bin-specific prediction: context indices differ per
                    # (tau, bin) but the backbone features do not.
                    c = h[b, t_ctx]
                    logits = head(c)

                    cifs = criterion.calculate_cifs(
                        logits, taus=tau_tensor
                    )
                    if cifs.ndim != 2:
                        raise ValueError(
                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
                            f"got shape={tuple(cifs.shape)}"
                        )

                    if cause_ids is None:
                        y = multi_hot_ever_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            n_disease=n_disease,
                        )
                        preds = cifs
                    else:
                        y = multi_hot_selected_causes_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            cause_ids=cause_ids,
                            n_disease=n_disease,
                        )
                        preds = cifs.index_select(dim=1, index=cause_ids)

                    y_true_mc[tau_idx][bin_idx].append(
                        y[keep].detach().to(dtype=torch.bool)
                    )
                    y_pred_mc[tau_idx][bin_idx].append(
                        preds[keep].detach().to(dtype=torch.float32)
                    )

        # Aggregate this MC immediately (frees GPU memory before the next MC).
        for h_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                if len(y_true_mc[h_idx][bin_idx]) == 0:
                    # No samples in this bin for this (mc, tau): record NaN
                    # rows so the bin still appears in the output table.
                    for cause_k in range(n_causes_eval):
                        cause_id = int(cause_k) if cause_ids is None else int(
                            cfg.cause_ids[cause_k])
                        for k_percent in topk_percents:
                            rows_by_bin.append(
                                dict(
                                    mc_idx=global_mc_idx,
                                    age_bin_id=bin_idx,
                                    age_bin_low=float(a_lo),
                                    age_bin_high=float(a_hi),
                                    horizon_tau=float(tau_y),
                                    topk_percent=float(k_percent),
                                    cause_id=cause_id,
                                    n_samples=0,
                                    n_positives=0,
                                    auc=float("nan"),
                                    auprc=float("nan"),
                                    recall_at_K=float("nan"),
                                    precision_at_K=float("nan"),
                                    brier_score=float("nan"),
                                )
                            )
                    continue

                yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
                pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0)
                if tuple(yb_t.shape) != tuple(pb_t.shape):
                    raise ValueError(
                        f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: "
                        f"y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
                    )

                n_samples = int(yb_t.size(0))

                metrics = compute_binary_metrics_torch(
                    y_true=yb_t,
                    y_pred=pb_t,
                    k_percents=topk_percents,
                    tie_mode="exact",
                    chunk_size=128,
                    compute_ici=False,
                )

                # Move just the metric vectors to CPU once per (mc, tau, bin)
                # for DataFrame construction.
                auc = metrics.auc_per_cause.detach().cpu().numpy()
                auprc = metrics.ap_per_cause.detach().cpu().numpy()
                brier = metrics.brier_per_cause.detach().cpu().numpy()
                n_pos = metrics.n_pos_per_cause.detach().cpu().numpy()
                prec_at_k = metrics.precision_at_k.detach().cpu().numpy()  # (P, K)
                rec_at_k = metrics.recall_at_k.detach().cpu().numpy()  # (P, K)

                for cause_k in range(n_causes_eval):
                    cause_id = int(cause_k) if cause_ids is None else int(
                        cfg.cause_ids[cause_k])
                    for p_idx, k_percent in enumerate(topk_percents):
                        rows_by_bin.append(
                            dict(
                                mc_idx=global_mc_idx,
                                age_bin_id=bin_idx,
                                age_bin_low=float(a_lo),
                                age_bin_high=float(a_hi),
                                horizon_tau=float(tau_y),
                                topk_percent=float(k_percent),
                                cause_id=cause_id,
                                n_samples=n_samples,
                                n_positives=int(n_pos[cause_k]),
                                auc=float(auc[cause_k]),
                                auprc=float(auprc[cause_k]),
                                recall_at_K=float(rec_at_k[p_idx, cause_k]),
                                precision_at_K=float(prec_at_k[p_idx, cause_k]),
                                brier_score=float(brier[cause_k]),
                            )
                        )

    df_by_bin = pd.DataFrame(rows_by_bin)

    df_agg = aggregate_age_bin_results(df_by_bin)

    return df_by_bin, df_agg
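

# Hedged end-to-end sketch (illustrative): wire a trained backbone/head pair,
# its survival criterion, and an evaluation dataloader into the age-bin
# evaluation above and persist both output tables. The config values and CSV
# paths are assumptions, not part of this module's contract.
def _example_run_age_bin_eval(model, head, criterion, dataloader,
                              n_disease: int, device: str = "cuda") -> None:
    cfg = EvalAgeConfig(
        horizons_years=(1.0, 5.0, 10.0),
        age_bins=((40.0, 50.0), (50.0, 60.0), (60.0, 70.0), (70.0, 80.0)),
        topk_percents=(1.0, 5.0, 10.0),
        n_mc=5,
        seed=0,
    )
    df_by_bin, df_agg = evaluate_time_dependent_age_bins(
        model=model,
        head=head,
        criterion=criterion,
        dataloader=dataloader,
        n_disease=n_disease,
        cfg=cfg,
        device=device,
    )
    df_by_bin.to_csv("eval_age_bins_by_bin.csv", index=False)
    df_agg.to_csv("eval_age_bins_agg.csv", index=False)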