Add binary metrics computation and refactor evaluation logic in age-bin evaluation

2026-01-16 17:19:27 +08:00
parent b1647d1b74
commit 810c75e6d1
2 changed files with 573 additions and 145 deletions


@@ -20,6 +20,8 @@ from utils import (
sample_context_in_fixed_age_bin,
)
+from torch_metrics import compute_binary_metrics_torch
def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
"""Aggregate per-bin age evaluation results.
@@ -173,105 +175,8 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
return df_agg
-def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
-"""Compute ROC AUC for binary labels with tie-aware ranking.
-Returns NaN if y_true has no positives or no negatives.
-Uses the Mann-Whitney U statistic with average ranks for ties.
-"""
-y_true = np.asarray(y_true).astype(bool)
-y_score = np.asarray(y_score).astype(float)
-n = y_true.size
-if n == 0:
-return float("nan")
-n_pos = int(y_true.sum())
-n_neg = n - n_pos
-if n_pos == 0 or n_neg == 0:
-return float("nan")
-# Rank scores ascending, average ranks for ties.
-order = np.argsort(y_score, kind="mergesort")
-sorted_scores = y_score[order]
-ranks = np.empty(n, dtype=float)
-i = 0
-# 1-based ranks
-while i < n:
-j = i + 1
-while j < n and sorted_scores[j] == sorted_scores[i]:
-j += 1
-avg_rank = 0.5 * ((i + 1) + j) # ranks i+1 .. j
-ranks[order[i:j]] = avg_rank
-i = j
-sum_ranks_pos = float(ranks[y_true].sum())
-u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
-return float(u / (n_pos * n_neg))
-def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
-"""Average precision (area under PR curve using step-wise interpolation).
-Returns NaN if no positives.
-"""
-y_true = np.asarray(y_true).astype(bool)
-y_score = np.asarray(y_score).astype(float)
-n = y_true.size
-if n == 0:
-return float("nan")
-n_pos = int(y_true.sum())
-if n_pos == 0:
-return float("nan")
-order = np.argsort(-y_score, kind="mergesort")
-y = y_true[order]
-tp = np.cumsum(y).astype(float)
-fp = np.cumsum(~y).astype(float)
-precision = tp / np.maximum(tp + fp, 1.0)
-# AP = sum over each positive of precision at that point / n_pos
-# (equivalent to ∑ Δrecall * precision)
-ap = float(np.sum(precision[y]) / n_pos)
-# handle potential tiny numerical overshoots
-return float(max(0.0, min(1.0, ap)))
-def _precision_recall_at_k_percent(
-y_true: np.ndarray,
-y_score: np.ndarray,
-k_percent: float,
-) -> Tuple[float, float]:
-"""Precision@K% and Recall@K% for binary labels.
-Returns (precision, recall). Returns NaN for recall if no positives.
-Returns NaN for precision if k leads to 0 selected.
-"""
-y_true = np.asarray(y_true).astype(bool)
-y_score = np.asarray(y_score).astype(float)
-n = y_true.size
-if n == 0:
-return float("nan"), float("nan")
-n_pos = int(y_true.sum())
-k = int(math.ceil((float(k_percent) / 100.0) * n))
-if k <= 0:
-return float("nan"), float("nan")
-order = np.argsort(-y_score, kind="mergesort")
-top = order[:k]
-tp_top = int(y_true[top].sum())
-precision = tp_top / k
-recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
-return float(precision), float(recall)
+# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`.
+# NumPy/Pandas are only used for final CSV formatting/aggregation.
@dataclass
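
To illustrate the NOTE above, here is a minimal sketch (illustrative names and shapes, not part of the commit) of the intended pattern: metrics are computed torch-native on whatever device the tensors live on, and each per-cause vector crosses to CPU/NumPy exactly once for CSV-style formatting.

import pandas as pd
import torch
from torch_metrics import compute_binary_metrics_torch

# Hypothetical (N, K) labels and scores, already on the compute device.
y = torch.randint(0, 2, (1000, 8), dtype=torch.bool)
p = torch.rand(1000, 8)
m = compute_binary_metrics_torch(y_true=y, y_pred=p, k_percents=(5.0, 10.0))
# One CPU transfer per metric vector, then pandas only for formatting.
df = pd.DataFrame({
    "auc": m.auc_per_cause.cpu().numpy(),
    "auprc": m.ap_per_cause.cpu().numpy(),
    "brier": m.brier_per_cause.cpu().numpy(),
})
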
@@ -341,18 +246,21 @@ def evaluate_time_dependent_age_bins(
list(cfg.cause_ids), dtype=torch.long, device=device)
n_causes_eval = int(cause_ids.numel())
-# Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
-y_true: List[List[List[List[torch.Tensor]]]] = [
-[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
-for _ in range(int(cfg.n_mc))
-]
-y_pred: List[List[List[List[torch.Tensor]]]] = [
-[[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
-for _ in range(int(cfg.n_mc))
-]
+rows_by_bin: List[Dict[str, float | int]] = []
for mc_idx in range(int(cfg.n_mc)):
+global_mc_idx = int(mc_offset) + int(mc_idx)
+# Storage for this MC only: (tau, bin) -> list of GPU tensors.
+# This keeps computations GPU-first while preventing a factor-n_mc
+# blow-up in GPU memory.
+y_true_mc: List[List[List[torch.Tensor]]] = [
+[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
+]
+y_pred_mc: List[List[List[torch.Tensor]]] = [
+[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
+]
# tqdm over batches; include MC idx in description.
for batch_idx, batch in enumerate(
tqdm(dataloader,
@@ -429,24 +337,17 @@ def evaluate_time_dependent_age_bins(
)
preds = cifs.index_select(dim=1, index=cause_ids)
-# Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors
-# and convert to NumPy once during aggregation.
-y_true[mc_idx][tau_idx][bin_idx].append(
-y[keep].detach().to(dtype=torch.bool,
-device="cpu", non_blocking=True)
+y_true_mc[tau_idx][bin_idx].append(
+y[keep].detach().to(dtype=torch.bool)
)
-y_pred[mc_idx][tau_idx][bin_idx].append(
-preds[keep].detach().to(dtype=torch.float32,
-device="cpu", non_blocking=True)
+y_pred_mc[tau_idx][bin_idx].append(
+preds[keep].detach().to(dtype=torch.float32)
)
-rows_by_bin: List[Dict[str, float | int]] = []
-for mc_idx in range(int(cfg.n_mc)):
-global_mc_idx = int(mc_offset) + int(mc_idx)
+# Aggregate this MC immediately (frees GPU memory before next MC).
for h_idx, tau_y in enumerate(horizons_years):
for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
-if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
+if len(y_true_mc[h_idx][bin_idx]) == 0:
# No samples in this bin for this (mc, tau)
for cause_k in range(n_causes_eval):
cause_id = int(cause_k) if cause_ids is None else int(
@@ -472,34 +373,37 @@ def evaluate_time_dependent_age_bins(
)
continue
-yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
-pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
+yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
+pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0)
if tuple(yb_t.shape) != tuple(pb_t.shape):
raise ValueError(
f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
)
-yb = yb_t.numpy()
-pb = pb_t.numpy()
+n_samples = int(yb_t.size(0))
-n_samples = int(yb.shape[0])
+metrics = compute_binary_metrics_torch(
+y_true=yb_t,
+y_pred=pb_t,
+k_percents=topk_percents,
+tie_mode="exact",
+chunk_size=128,
+compute_ici=False,
+)
+# Move just the metric vectors to CPU once per (mc, tau, bin)
+# for DataFrame construction.
+auc = metrics.auc_per_cause.detach().cpu().numpy()
+auprc = metrics.ap_per_cause.detach().cpu().numpy()
+brier = metrics.brier_per_cause.detach().cpu().numpy()
+n_pos = metrics.n_pos_per_cause.detach().cpu().numpy()
+prec_at_k = metrics.precision_at_k.detach().cpu().numpy() # (P,K)
+rec_at_k = metrics.recall_at_k.detach().cpu().numpy() # (P,K)
for cause_k in range(n_causes_eval):
-yk = yb[:, cause_k]
-pk = pb[:, cause_k]
-n_pos = int(yk.sum())
-auc = _binary_roc_auc(yk, pk)
-auprc = _average_precision(yk, pk)
-brier = float(np.mean(
-(yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan")
cause_id = int(cause_k) if cause_ids is None else int(
cfg.cause_ids[cause_k])
-for k_percent in topk_percents:
-precision_k, recall_k = _precision_recall_at_k_percent(
-yk, pk, float(k_percent))
+for p_idx, k_percent in enumerate(topk_percents):
rows_by_bin.append(
dict(
mc_idx=global_mc_idx,
@@ -510,12 +414,12 @@ def evaluate_time_dependent_age_bins(
topk_percent=float(k_percent),
cause_id=cause_id,
n_samples=n_samples,
-n_positives=n_pos,
-auc=float(auc),
-auprc=float(auprc),
-recall_at_K=float(recall_k),
-precision_at_K=float(precision_k),
-brier_score=float(brier),
+n_positives=int(n_pos[cause_k]),
+auc=float(auc[cause_k]),
+auprc=float(auprc[cause_k]),
+recall_at_K=float(rec_at_k[p_idx, cause_k]),
+precision_at_K=float(prec_at_k[p_idx, cause_k]),
+brier_score=float(brier[cause_k]),
)
)
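
The structural change above, reduced to a self-contained sketch (all names, sizes, and the random bin routing here are illustrative stand-ins for the real dataloader and model): buffers hold a single MC round, aggregation happens before the next round starts, so peak GPU memory is O(one MC) rather than O(n_mc).

import torch
from torch_metrics import compute_binary_metrics_torch

n_mc, n_bins, n_batches, K = 4, 3, 5, 8
for mc_idx in range(n_mc):
    # Buffers hold one MC round only and are released before the next.
    y_bufs = [[] for _ in range(n_bins)]
    p_bufs = [[] for _ in range(n_bins)]
    for _ in range(n_batches):
        b = int(torch.randint(0, n_bins, ()).item())   # stand-in bin routing
        y_bufs[b].append(torch.randint(0, 2, (32, K), dtype=torch.bool))
        p_bufs[b].append(torch.rand(32, K))
    for b in range(n_bins):
        if not y_bufs[b]:
            continue
        res = compute_binary_metrics_torch(
            torch.cat(y_bufs[b]), torch.cat(p_bufs[b]), k_percents=(10.0,)
        )
        # ...build result rows for this (mc, bin) from `res` here...
    del y_bufs, p_bufs   # peak memory stays O(one MC), not O(n_mc)
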

torch_metrics.py (new file, 524 lines)

@@ -0,0 +1,524 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Dict, Iterable, List, Literal, Optional, Sequence, Tuple
import torch
TieMode = Literal["exact", "approx"]
def _stable_sort(
x: torch.Tensor,
*,
dim: int,
descending: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Stable torch sort when available.
Determinism notes:
- When `stable=True` is supported by the installed PyTorch, we request it.
- Otherwise we fall back to `torch.sort`. For identical inputs on the same
device/runtime, this is typically deterministic, but tie ordering is not
guaranteed to be stable across versions.
"""
try:
return torch.sort(x, dim=dim, descending=descending, stable=True)
except TypeError:
return torch.sort(x, dim=dim, descending=descending)
def _nanmean(x: torch.Tensor) -> torch.Tensor:
mask = torch.isfinite(x)
if not bool(mask.any()):
return torch.tensor(float("nan"), device=x.device, dtype=x.dtype)
return x[mask].mean()
def _nanweighted_mean(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
x = x.to(torch.float32)
w = w.to(torch.float32)
mask = torch.isfinite(x) & torch.isfinite(w) & (w > 0)
if not bool(mask.any()):
return torch.tensor(float("nan"), device=x.device, dtype=torch.float32)
ww = w[mask]
return (x[mask] * ww).sum() / ww.sum()
def _validate_binary_inputs(y_true: torch.Tensor, y_score: torch.Tensor) -> None:
if y_true.ndim != 2 or y_score.ndim != 2:
raise ValueError(
f"Expected y_true and y_score to be 2D (N,K); got {tuple(y_true.shape)} and {tuple(y_score.shape)}"
)
if tuple(y_true.shape) != tuple(y_score.shape):
raise ValueError(
f"Shape mismatch: y_true{tuple(y_true.shape)} vs y_score{tuple(y_score.shape)}"
)
def brier_per_cause(y_true: torch.Tensor, y_score: torch.Tensor) -> torch.Tensor:
"""Brier score per cause.
Args:
y_true: (N,K) bool/int tensor
y_score: (N,K) float tensor
Returns:
(K,) float32 tensor; NaN if N==0.
"""
_validate_binary_inputs(y_true, y_score)
if y_true.numel() == 0:
return torch.full((y_true.size(1),), float("nan"), device=y_true.device, dtype=torch.float32)
yt = y_true.to(torch.float32)
ys = y_score.to(torch.float32)
return ((ys - yt) ** 2).mean(dim=0)
def ici_per_cause_fixed_width(
y_true: torch.Tensor,
y_score: torch.Tensor,
*,
n_bins: int = 15,
chunk_size: int = 128,
) -> torch.Tensor:
"""Integrated Calibration Index (ICI) via fixed-width bins on [0,1].
ICI per cause = E[ |p_bin - y_bin| ] where bin stats are computed over
fixed-width probability bins.
This is deterministic and GPU-friendly (scatter_add based).
Returns:
(K,) float32 tensor; NaN when N==0.
"""
_validate_binary_inputs(y_true, y_score)
if int(n_bins) <= 1:
raise ValueError("n_bins must be >= 2")
device = y_true.device
N, K = y_true.shape
if N == 0:
return torch.full((K,), float("nan"), device=device, dtype=torch.float32)
yt = y_true.to(torch.float32)
ys = y_score.to(torch.float32).clamp(0.0, 1.0)
out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)
for start in range(0, K, int(chunk_size)):
end = min(K, start + int(chunk_size))
ys_c = ys[:, start:end]
yt_c = yt[:, start:end]
# bin index in [0, n_bins-1]
bin_idx = torch.clamp(
(ys_c * float(n_bins)).to(torch.long), max=int(n_bins) - 1)
counts = torch.zeros((int(n_bins), end - start),
device=device, dtype=torch.float32)
pred_sums = torch.zeros_like(counts)
true_sums = torch.zeros_like(counts)
ones = torch.ones_like(ys_c, dtype=torch.float32)
counts.scatter_add_(0, bin_idx, ones)
pred_sums.scatter_add_(0, bin_idx, ys_c)
true_sums.scatter_add_(0, bin_idx, yt_c)
denom = counts.clamp(min=1.0)
pred_mean = pred_sums / denom
true_mean = true_sums / denom
abs_gap = (pred_mean - true_mean).abs()
# sample-weighted average of bin gap
total = counts.sum(dim=0).clamp(min=1.0)
ici = (abs_gap * counts).sum(dim=0) / total
# If a cause has no samples (shouldn't happen when N>0), mark NaN.
out[start:end] = torch.where(
total > 0, ici.to(torch.float32), out[start:end])
return out
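
A quick sanity check of the fixed-width binning (synthetic data, not from the commit): labels drawn with probability equal to the score should be nearly calibrated, while a shifted score should show roughly the size of the shift.

import torch
from torch_metrics import ici_per_cause_fixed_width

torch.manual_seed(0)
p = torch.rand(100_000, 1)
y = torch.rand(100_000, 1) < p            # P(y=1) = p: calibrated by design
print(ici_per_cause_fixed_width(y, p))                         # close to 0
print(ici_per_cause_fixed_width(y, (p + 0.1).clamp(max=1.0)))  # roughly 0.1
# (the gap shrinks near p=1 because of the clamp)
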
def average_precision_per_cause(
y_true: torch.Tensor,
y_score: torch.Tensor,
*,
tie_mode: TieMode = "exact",
chunk_size: int = 128,
) -> torch.Tensor:
"""Average precision (AP) per cause.
Definition matches sklearn's `average_precision_score` (step-wise PR integral):
AP = \\sum_i (R_i - R_{i-1}) * P_i
where i iterates over unique score thresholds.
tie_mode:
- "exact": tie-invariant AP by grouping identical scores (recommended)
- "approx": mean precision at positive ranks; can differ under ties
Returns:
(K,) float32 tensor with NaN for causes with 0 positives.
"""
_validate_binary_inputs(y_true, y_score)
device = y_true.device
N, K = y_true.shape
if N == 0:
return torch.full((K,), float("nan"), device=device, dtype=torch.float32)
yt = y_true.to(torch.bool)
ys = y_score.to(torch.float32)
n_pos_all = yt.sum(dim=0).to(torch.float32)
out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)
for start in range(0, K, int(chunk_size)):
end = min(K, start + int(chunk_size))
yt_c = yt[:, start:end]
ys_c = ys[:, start:end]
# For exact mode we need per-cause tie grouping; do per-cause loops
# within a chunk to keep memory bounded and stay on GPU.
for j in range(end - start):
n_pos = n_pos_all[start + j]
scores = ys_c[:, j]
labels = yt_c[:, j]
if tie_mode == "approx":
_, order = _stable_sort(scores, dim=0, descending=True)
y_sorted = labels.gather(0, order).to(torch.float32)
tp = y_sorted.cumsum(dim=0)
denom = torch.arange(
1, N + 1, device=device, dtype=torch.float32)
precision = tp / denom
n_pos_safe = torch.clamp(n_pos, min=1.0)
ap = (precision * y_sorted).sum() / n_pos_safe
out[start + j] = torch.where(n_pos > 0.0,
ap.to(torch.float32), out[start + j])
continue
# exact: group by unique score thresholds
sorted_scores, order = _stable_sort(scores, dim=0, descending=True)
y_sorted = labels.gather(0, order).to(torch.float32)
# group boundaries where score changes
change = torch.empty((N,), device=device, dtype=torch.bool)
change[0] = True
if N > 1:
change[1:] = sorted_scores[1:] != sorted_scores[:-1]
group_starts = change.nonzero(as_tuple=False).squeeze(1)
group_ends = torch.cat(
[group_starts[1:], torch.tensor(
[N], device=device, dtype=group_starts.dtype)]
) - 1
tp = y_sorted.cumsum(dim=0)
fp = torch.arange(1, N + 1, device=device, dtype=torch.float32) - tp
tp_end = tp[group_ends]
fp_end = fp[group_ends]
precision = tp_end / torch.clamp(tp_end + fp_end, min=1.0)
n_pos_safe = torch.clamp(n_pos, min=1.0)
recall = tp_end / n_pos_safe
recall_prev = torch.cat(
[torch.zeros((1,), device=device,
dtype=torch.float32), recall[:-1]]
)
ap = ((recall - recall_prev) * precision).sum()
out[start + j] = torch.where(n_pos > 0.0,
ap.to(torch.float32), out[start + j])
return out
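
Because the docstring pins the definition to sklearn's average_precision_score, a cross-check against it is a natural validation of the tie grouping. A sketch, assuming scikit-learn is available:

import torch
from sklearn.metrics import average_precision_score
from torch_metrics import average_precision_per_cause

torch.manual_seed(0)
y = torch.randint(0, 2, (500, 3), dtype=torch.bool)
s = torch.rand(500, 3).round(decimals=1)   # coarse score grid to force ties
ap = average_precision_per_cause(y, s, tie_mode="exact")
for k in range(3):
    ref = average_precision_score(y[:, k].numpy(), s[:, k].numpy())
    print(float(ap[k]), ref)               # expected to agree under ties
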
def auroc_per_cause(
y_true: torch.Tensor,
y_score: torch.Tensor,
*,
tie_mode: TieMode = "exact",
chunk_size: int = 128,
) -> torch.Tensor:
"""AUROC per cause via MannWhitney U.
AUC = (sum_ranks_pos - n_pos*(n_pos+1)/2) / (n_pos*n_neg)
tie_mode:
- "exact": average ranks for ties (matches typical sklearn tie behavior)
- "approx": ordinal ranks (faster, differs under ties)
Returns:
(K,) float32 tensor; NaN when a cause has n_pos==0 or n_neg==0.
"""
_validate_binary_inputs(y_true, y_score)
device = y_true.device
N, K = y_true.shape
if N == 0:
return torch.full((K,), float("nan"), device=device, dtype=torch.float32)
yt = y_true.to(torch.bool)
ys = y_score.to(torch.float32)
n_pos_all = yt.sum(dim=0).to(torch.float32)
n_neg_all = (float(N) - n_pos_all).to(torch.float32)
out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)
for start in range(0, K, int(chunk_size)):
end = min(K, start + int(chunk_size))
yt_c = yt[:, start:end]
ys_c = ys[:, start:end]
for j in range(end - start):
n_pos = n_pos_all[start + j]
n_neg = n_neg_all[start + j]
scores = ys_c[:, j]
labels = yt_c[:, j]
sorted_scores, order = _stable_sort(scores, dim=0, descending=False)
y_sorted = labels.gather(0, order).to(torch.float32)
if tie_mode == "approx":
ranks = torch.arange(
1, N + 1, device=device, dtype=torch.float32)
else:
# average ranks for ties
change = torch.empty((N,), device=device, dtype=torch.bool)
change[0] = True
if N > 1:
change[1:] = sorted_scores[1:] != sorted_scores[:-1]
group_starts = change.nonzero(as_tuple=False).squeeze(1)
group_ends = torch.cat(
[group_starts[1:], torch.tensor(
[N], device=device, dtype=group_starts.dtype)]
) - 1
lengths = (group_ends - group_starts + 1).to(torch.long)
start_rank = (group_starts + 1).to(torch.float32)
end_rank = (group_ends + 1).to(torch.float32)
avg_rank = 0.5 * (start_rank + end_rank)
ranks = avg_rank.repeat_interleave(lengths)
sum_ranks_pos = (ranks * y_sorted).sum()
u = sum_ranks_pos - (n_pos * (n_pos + 1.0) / 2.0)
denom = n_pos * n_neg
auc = u / torch.clamp(denom, min=1.0)
valid = (n_pos > 0.0) & (n_neg > 0.0)
out[start + j] = torch.where(valid,
auc.to(torch.float32), out[start + j])
return out
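
A small worked example of the tie handling (hypothetical scores): with one positive and one negative tied at the top, average ranks credit the tied pair half a win, while ordinal ranks resolve it by stable sort order.

import torch
from torch_metrics import auroc_per_cause

y = torch.tensor([1, 0, 1, 0], dtype=torch.bool).unsqueeze(1)
s = torch.tensor([0.9, 0.9, 0.5, 0.1]).unsqueeze(1)
print(auroc_per_cause(y, s, tie_mode="exact"))   # 0.625: tied pair counts 0.5
print(auroc_per_cause(y, s, tie_mode="approx"))  # 0.5 here: stable order puts
                                                 # the tied positive first
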
def precision_recall_at_k_percents_per_cause(
y_true: torch.Tensor,
y_score: torch.Tensor,
k_percents: Sequence[float],
*,
chunk_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precision@K% and Recall@K% per cause.
Uses stable sort (descending) to match deterministic tie behavior.
Returns:
precision: (P,K) float32
recall: (P,K) float32
"""
_validate_binary_inputs(y_true, y_score)
device = y_true.device
N, K = y_true.shape
P = int(len(k_percents))
precision = torch.full((P, K), float(
"nan"), device=device, dtype=torch.float32)
recall = torch.full((P, K), float(
"nan"), device=device, dtype=torch.float32)
if N == 0:
return precision, recall
yt = y_true.to(torch.bool)
ys = y_score.to(torch.float32)
n_pos_all = yt.sum(dim=0).to(torch.float32)
ks: List[int] = []
for kp in k_percents:
k = int(math.ceil((float(kp) / 100.0) * float(N)))
ks.append(k)
for start in range(0, K, int(chunk_size)):
end = min(K, start + int(chunk_size))
yt_c = yt[:, start:end]
ys_c = ys[:, start:end]
for j in range(end - start):
scores = ys_c[:, j]
labels = yt_c[:, j]
n_pos = n_pos_all[start + j]
# stable descending order
_, order = _stable_sort(scores, dim=0, descending=True)
y_sorted = labels.gather(0, order).to(torch.float32)
tp = y_sorted.cumsum(dim=0)
for p_idx, k in enumerate(ks):
if k <= 0:
continue
tp_k = tp[min(k, N) - 1]
precision[p_idx, start + j] = tp_k / float(k)
recall[p_idx, start + j] = torch.where(
n_pos > 0.0,
tp_k / n_pos,
torch.tensor(float("nan"), device=device,
dtype=torch.float32),
)
return precision, recall
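
A worked example of the K% conversion (toy numbers): with N = 10 and K = 20%, k = ceil(0.2 * 10) = 2, so only the top-2 scores are inspected.

import torch
from torch_metrics import precision_recall_at_k_percents_per_cause

y = torch.tensor([1, 0, 1, 0, 0, 0, 1, 0, 0, 0], dtype=torch.bool).unsqueeze(1)
s = torch.linspace(1.0, 0.1, 10).unsqueeze(1)    # strictly descending scores
prec, rec = precision_recall_at_k_percents_per_cause(y, s, (20.0,))
print(prec)   # 0.5: one of the top-2 rows is positive
print(rec)    # 0.3333: 1 of the 3 positives captured
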
@dataclass
class BinaryMetricsResult:
auc_per_cause: torch.Tensor # (K,)
ap_per_cause: torch.Tensor # (K,)
brier_per_cause: torch.Tensor # (K,)
precision_at_k: torch.Tensor # (P,K)
recall_at_k: torch.Tensor # (P,K)
n_pos_per_cause: torch.Tensor # (K,)
n_neg_per_cause: torch.Tensor # (K,)
ici_per_cause: Optional[torch.Tensor] = None # (K,)
@torch.inference_mode()
def compute_binary_metrics_torch(
y_true: torch.Tensor,
y_pred: torch.Tensor,
*,
device: str | torch.device | None = None,
k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0),
tie_mode: TieMode = "exact",
chunk_size: int = 128,
compute_ici: bool = False,
ici_bins: int = 15,
) -> BinaryMetricsResult:
"""Compute per-cause binary ranking metrics on GPU using torch.
Inputs must be (N,K) and live on the device you want to compute on.
Performance notes:
- Computation is chunked over causes to bound peak memory.
- For `tie_mode="exact"`, AP and AUROC are computed with tie grouping, which
is more correct under ties but uses per-cause loops (still GPU-resident).
Determinism:
- Uses stable sorts where available.
- Avoids nondeterministic selection ops for ties (no `topk`).
"""
_validate_binary_inputs(y_true, y_pred)
if device is not None:
device = torch.device(device)
y_true = y_true.to(device)
y_pred = y_pred.to(device)
N, K = y_true.shape
yt = y_true.to(torch.bool)
yp = y_pred.to(torch.float32)
n_pos = yt.sum(dim=0).to(torch.long)
n_neg = (int(N) - n_pos).to(torch.long)
auc = auroc_per_cause(yt, yp, tie_mode=tie_mode, chunk_size=chunk_size)
ap = average_precision_per_cause(
yt, yp, tie_mode=tie_mode, chunk_size=chunk_size)
brier = brier_per_cause(yt, yp)
prec_k, rec_k = precision_recall_at_k_percents_per_cause(
yt, yp, k_percents, chunk_size=chunk_size
)
ici = None
if compute_ici:
ici = ici_per_cause_fixed_width(
yt, yp, n_bins=int(ici_bins), chunk_size=chunk_size)
return BinaryMetricsResult(
auc_per_cause=auc,
ap_per_cause=ap,
brier_per_cause=brier,
precision_at_k=prec_k,
recall_at_k=rec_k,
n_pos_per_cause=n_pos,
n_neg_per_cause=n_neg,
ici_per_cause=ici,
)
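
A minimal end-to-end call with synthetic inputs, showing the shapes the evaluation script above consumes:

import torch
from torch_metrics import compute_binary_metrics_torch

y = torch.randint(0, 2, (256, 4), dtype=torch.bool)
p = torch.rand(256, 4)
res = compute_binary_metrics_torch(y, p, k_percents=(5.0, 20.0), compute_ici=True)
print(res.auc_per_cause.shape)    # torch.Size([4]): one value per cause
print(res.precision_at_k.shape)   # torch.Size([2, 4]): (P, K)
print(int((res.n_pos_per_cause + res.n_neg_per_cause)[0]))  # 256 == N
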
@torch.inference_mode()
def compute_metrics_torch(
y_true: torch.Tensor,
y_pred: torch.Tensor,
*,
device: str | torch.device | None = None,
weights: Optional[torch.Tensor] = None,
k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0),
tie_mode: TieMode = "exact",
chunk_size: int = 128,
compute_ici: bool = False,
ici_bins: int = 15,
) -> Dict[str, object]:
"""Convenience API: per-cause + macro/weighted aggregations.
Returns a dict compatible with downstream reporting:
- per-cause tensors under `per_cause`
- macro + weighted summaries (NaN-aware)
If `weights` is None, uses number of positives per cause as weights.
"""
res = compute_binary_metrics_torch(
y_true,
y_pred,
device=device,
k_percents=k_percents,
tie_mode=tie_mode,
chunk_size=chunk_size,
compute_ici=compute_ici,
ici_bins=ici_bins,
)
w = res.n_pos_per_cause.to(
torch.float32) if weights is None else weights.to(torch.float32)
out: Dict[str, object] = {
"auc_macro": _nanmean(res.auc_per_cause),
"auc_weighted": _nanweighted_mean(res.auc_per_cause, w),
"ap_macro": _nanmean(res.ap_per_cause),
"ap_weighted": _nanweighted_mean(res.ap_per_cause, w),
"brier_macro": _nanmean(res.brier_per_cause),
"brier_weighted": _nanweighted_mean(res.brier_per_cause, w),
"per_cause": {
"auc": res.auc_per_cause,
"ap": res.ap_per_cause,
"brier": res.brier_per_cause,
"precision_at_k": res.precision_at_k,
"recall_at_k": res.recall_at_k,
"n_pos": res.n_pos_per_cause,
"n_neg": res.n_neg_per_cause,
},
}
if res.ici_per_cause is not None:
out["ici_macro"] = _nanmean(res.ici_per_cause)
out["ici_weighted"] = _nanweighted_mean(res.ici_per_cause, w)
out["per_cause"]["ici"] = res.ici_per_cause
return out
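
And the convenience wrapper in use (synthetic inputs; macro values are NaN-aware means over causes, weighted values use positives per cause by default):

import torch
from torch_metrics import compute_metrics_torch

y = torch.randint(0, 2, (512, 6), dtype=torch.bool)
p = torch.rand(512, 6)
out = compute_metrics_torch(y, p, k_percents=(10.0,))
print(float(out["auc_macro"]), float(out["auc_weighted"]))
print(out["per_cause"]["auc"].shape)   # torch.Size([6])
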