525 lines
17 KiB
Python
525 lines
17 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import math
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Dict, Iterable, List, Literal, Optional, Sequence, Tuple
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Tie-handling strategy accepted by the ranking metrics in this module
# (average_precision_per_cause, auroc_per_cause and the compute_*_metrics_torch
# wrappers): "exact" groups identical scores, "approx" uses ordinal ranks.
TieMode = Literal["exact", "approx"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _stable_sort(
    x: torch.Tensor,
    *,
    dim: int,
    descending: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sort ``x`` along ``dim``, requesting a stable sort when PyTorch supports it.

    Determinism notes:
    - Builds that accept the ``stable`` keyword are called with ``stable=True``,
      so equal elements keep their original relative order.
    - Builds that reject the keyword (``TypeError``) fall back to a plain
      ``torch.sort``. For identical inputs on the same device/runtime this is
      typically reproducible, but the order of tied elements is not guaranteed
      to be stable across versions.

    Returns:
        ``(values, indices)`` exactly as produced by ``torch.sort``.
    """
    sort_kwargs = {"dim": dim, "descending": descending}
    try:
        return torch.sort(x, stable=True, **sort_kwargs)
    except TypeError:
        # Older torch without the `stable` keyword.
        return torch.sort(x, **sort_kwargs)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _nanmean(x: torch.Tensor) -> torch.Tensor:
    """Mean over the finite entries of ``x``; NaN and +/-inf entries are ignored.

    Returns a 0-d tensor. When ``x`` has no finite entry at all, the result is
    NaN with the same device and dtype as ``x``.
    """
    finite_vals = x[torch.isfinite(x)]
    if finite_vals.numel() == 0:
        return torch.tensor(float("nan"), device=x.device, dtype=x.dtype)
    return finite_vals.mean()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _nanweighted_mean(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Weighted mean of ``x`` under weights ``w``, computed in float32.

    An entry contributes only if its value and its weight are both finite and
    the weight is strictly positive. Returns a 0-d float32 tensor, or NaN when
    no entry contributes.
    """
    xf = x.float()
    wf = w.float()
    keep = torch.isfinite(xf) & torch.isfinite(wf) & (wf > 0)
    if not keep.any().item():
        return torch.tensor(float("nan"), device=xf.device, dtype=torch.float32)
    kept_w = wf[keep]
    numerator = (xf[keep] * kept_w).sum()
    return numerator / kept_w.sum()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _validate_binary_inputs(y_true: torch.Tensor, y_score: torch.Tensor) -> None:
    """Raise ``ValueError`` unless both tensors are 2-D ``(N, K)`` with equal shapes."""
    true_shape = tuple(y_true.shape)
    score_shape = tuple(y_score.shape)
    both_2d = y_true.ndim == 2 and y_score.ndim == 2
    if not both_2d:
        raise ValueError(
            f"Expected y_true and y_score to be 2D (N,K); got {true_shape} and {score_shape}"
        )
    if true_shape != score_shape:
        raise ValueError(f"Shape mismatch: y_true{true_shape} vs y_score{score_shape}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def brier_per_cause(y_true: torch.Tensor, y_score: torch.Tensor) -> torch.Tensor:
    """Per-cause Brier score: mean squared gap between score and label.

    Args:
        y_true: (N,K) bool/int label tensor.
        y_score: (N,K) float score tensor.

    Returns:
        (K,) float32 tensor; all-NaN when the input is empty (N == 0).

    Raises:
        ValueError: if the inputs are not 2-D tensors of equal shape.
    """
    _validate_binary_inputs(y_true, y_score)
    n_causes = y_true.size(1)
    if y_true.numel() == 0:
        return torch.full((n_causes,), float("nan"), device=y_true.device, dtype=torch.float32)
    residual = y_score.to(torch.float32) - y_true.to(torch.float32)
    return residual.pow(2).mean(dim=0)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def ici_per_cause_fixed_width(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    *,
    n_bins: int = 15,
    chunk_size: int = 128,
) -> torch.Tensor:
    """Integrated Calibration Index (ICI) via fixed-width bins on [0,1].

    ICI per cause = E[ |p_bin - y_bin| ], where p_bin / y_bin are the mean
    predicted score / mean label inside each of ``n_bins`` equal-width
    probability bins, and the expectation weights each bin by its sample count.
    Scores are clamped to [0, 1] before binning.

    This is deterministic and GPU-friendly (scatter_add based).

    Args:
        y_true: (N,K) bool/int label tensor.
        y_score: (N,K) float score tensor.
        n_bins: number of equal-width bins; must be >= 2.
        chunk_size: number of causes processed per vectorized pass; must be >= 1.

    Returns:
        (K,) float32 tensor; NaN when N==0. A NaN score makes the ICI of its
        cause NaN (instead of raising or being silently binned).

    Raises:
        ValueError: on non-2-D / mismatched inputs, ``n_bins < 2`` or
            ``chunk_size < 1``.
    """
    _validate_binary_inputs(y_true, y_score)
    n_bins = int(n_bins)
    chunk_size = int(chunk_size)
    if n_bins <= 1:
        raise ValueError("n_bins must be >= 2")
    if chunk_size < 1:
        # A non-positive step would make the chunk loop below empty and return
        # all-NaN silently (or raise an opaque range() error for 0).
        raise ValueError("chunk_size must be >= 1")

    device = y_true.device
    N, K = y_true.shape
    if N == 0:
        return torch.full((K,), float("nan"), device=device, dtype=torch.float32)

    yt = y_true.to(torch.float32)
    ys = y_score.to(torch.float32).clamp(0.0, 1.0)  # clamp keeps NaN as NaN

    out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)

    for start in range(0, K, chunk_size):
        end = min(K, start + chunk_size)
        ys_c = ys[:, start:end]
        yt_c = yt[:, start:end]

        # Bin index in [0, n_bins-1]. Casting NaN to an integer is undefined
        # (on x86 CPUs it typically yields INT64_MIN, an out-of-range index for
        # scatter_add_), so NaN scores are routed to bin 0 explicitly; their NaN
        # value still flows into pred_sums and makes that cause's ICI NaN.
        # A score of exactly 1.0 maps to n_bins and is clamped into the last bin.
        scaled = torch.nan_to_num(ys_c * float(n_bins), nan=0.0)
        bin_idx = scaled.to(torch.long).clamp_(0, n_bins - 1)

        counts = torch.zeros((n_bins, end - start), device=device, dtype=torch.float32)
        pred_sums = torch.zeros_like(counts)
        true_sums = torch.zeros_like(counts)

        counts.scatter_add_(0, bin_idx, torch.ones_like(ys_c))
        pred_sums.scatter_add_(0, bin_idx, ys_c)
        true_sums.scatter_add_(0, bin_idx, yt_c)

        # Empty bins have count 0; clamping the divisor avoids 0/0, and those
        # bins receive zero weight in the average below anyway.
        denom = counts.clamp(min=1.0)
        abs_gap = (pred_sums / denom - true_sums / denom).abs()

        # Sample-weighted average of the per-bin gaps.
        total = counts.sum(dim=0)
        ici = (abs_gap * counts).sum(dim=0) / total.clamp(min=1.0)

        # A cause with no samples (impossible when N > 0) stays NaN.
        out[start:end] = torch.where(total > 0, ici, out[start:end])

    return out
|
|||
|
|
|
|||
|
|
|
|||
|
|
def average_precision_per_cause(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    *,
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
) -> torch.Tensor:
    """Average precision (AP) per cause.

    Definition matches sklearn's `average_precision_score` (step-wise PR integral):
        AP = \\sum_i (R_i - R_{i-1}) * P_i
    where i iterates over unique score thresholds in descending order and
    R_0 = 0.

    tie_mode:
      - "exact": tie-invariant AP by grouping identical scores (recommended)
      - "approx": mean precision at positive ranks; can differ under ties

    Args:
        y_true: (N,K) bool/int label tensor (nonzero = positive).
        y_score: (N,K) float score tensor.
        tie_mode: "exact" or "approx" (see above).
        chunk_size: number of causes sliced per pass; must be >= 1.

    Returns:
        (K,) float32 tensor with NaN for causes with 0 positives (all-NaN when
        N==0).

    Raises:
        ValueError: on non-2-D / mismatched inputs, an unknown ``tie_mode`` or
            ``chunk_size < 1``.
    """
    _validate_binary_inputs(y_true, y_score)
    if tie_mode not in ("exact", "approx"):
        # Previously any other string silently fell through to "exact".
        raise ValueError(f"tie_mode must be 'exact' or 'approx'; got {tie_mode!r}")
    chunk_size = int(chunk_size)
    if chunk_size < 1:
        # A non-positive step would make the chunk loop empty -> silent all-NaN.
        raise ValueError("chunk_size must be >= 1")

    device = y_true.device
    N, K = y_true.shape
    if N == 0:
        return torch.full((K,), float("nan"), device=device, dtype=torch.float32)

    yt = y_true.to(torch.bool)
    ys = y_score.to(torch.float32)

    n_pos_all = yt.sum(dim=0).to(torch.float32)
    out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)

    # Loop-invariant tensors shared by every cause.
    ranks = torch.arange(1, N + 1, device=device, dtype=torch.float32)  # samples at or above each position
    zero = torch.zeros((1,), device=device, dtype=torch.float32)
    n_end = torch.tensor([N], device=device, dtype=torch.long)

    for start in range(0, K, chunk_size):
        end = min(K, start + chunk_size)
        yt_c = yt[:, start:end]
        ys_c = ys[:, start:end]

        # Tie grouping differs per cause, so causes are handled one at a time
        # within a chunk; this keeps memory bounded and stays on `device`.
        for j in range(end - start):
            n_pos = n_pos_all[start + j]
            n_pos_safe = torch.clamp(n_pos, min=1.0)

            scores = ys_c[:, j]
            labels = yt_c[:, j]

            sorted_scores, order = _stable_sort(scores, dim=0, descending=True)
            y_sorted = labels.gather(0, order).to(torch.float32)
            tp = y_sorted.cumsum(dim=0)

            if tie_mode == "approx":
                # Mean of precision@rank taken over the positive ranks.
                ap = (tp / ranks * y_sorted).sum() / n_pos_safe
            else:
                # exact: evaluate P/R only at the last index of each run of
                # equal scores, i.e. once per unique threshold.
                change = torch.empty((N,), device=device, dtype=torch.bool)
                change[0] = True
                if N > 1:
                    change[1:] = sorted_scores[1:] != sorted_scores[:-1]
                group_starts = change.nonzero(as_tuple=False).squeeze(1)
                group_ends = torch.cat([group_starts[1:], n_end]) - 1

                fp = ranks - tp
                tp_end = tp[group_ends]
                fp_end = fp[group_ends]
                precision = tp_end / torch.clamp(tp_end + fp_end, min=1.0)
                recall = tp_end / n_pos_safe
                recall_prev = torch.cat([zero, recall[:-1]])
                ap = ((recall - recall_prev) * precision).sum()

            # Causes without positives keep their NaN.
            out[start + j] = torch.where(n_pos > 0.0, ap.to(torch.float32), out[start + j])

    return out
|
|||
|
|
|
|||
|
|
|
|||
|
|
def auroc_per_cause(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    *,
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
) -> torch.Tensor:
    """AUROC per cause via Mann–Whitney U.

        AUC = (sum_ranks_pos - n_pos*(n_pos+1)/2) / (n_pos*n_neg)

    with ranks taken over scores sorted ascending (rank 1 = lowest score).

    tie_mode:
      - "exact": average ranks for ties (matches typical sklearn tie behavior)
      - "approx": ordinal ranks (faster, differs under ties)

    Args:
        y_true: (N,K) bool/int label tensor (nonzero = positive).
        y_score: (N,K) float score tensor.
        tie_mode: "exact" or "approx" (see above).
        chunk_size: number of causes sliced per pass; must be >= 1.

    Returns:
        (K,) float32 tensor; NaN when a cause has n_pos==0 or n_neg==0
        (all-NaN when N==0).

    Raises:
        ValueError: on non-2-D / mismatched inputs, an unknown ``tie_mode`` or
            ``chunk_size < 1``.
    """
    _validate_binary_inputs(y_true, y_score)
    if tie_mode not in ("exact", "approx"):
        # Previously any other string silently fell through to "exact".
        raise ValueError(f"tie_mode must be 'exact' or 'approx'; got {tie_mode!r}")
    chunk_size = int(chunk_size)
    if chunk_size < 1:
        # A non-positive step would make the chunk loop empty -> silent all-NaN.
        raise ValueError("chunk_size must be >= 1")

    device = y_true.device
    N, K = y_true.shape
    if N == 0:
        return torch.full((K,), float("nan"), device=device, dtype=torch.float32)

    yt = y_true.to(torch.bool)
    ys = y_score.to(torch.float32)

    n_pos_all = yt.sum(dim=0).to(torch.float32)
    n_neg_all = (float(N) - n_pos_all).to(torch.float32)

    out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)

    # Loop-invariant tensors shared by every cause.
    ordinal_ranks = torch.arange(1, N + 1, device=device, dtype=torch.float32)
    n_end = torch.tensor([N], device=device, dtype=torch.long)

    for start in range(0, K, chunk_size):
        end = min(K, start + chunk_size)
        yt_c = yt[:, start:end]
        ys_c = ys[:, start:end]

        for j in range(end - start):
            n_pos = n_pos_all[start + j]
            n_neg = n_neg_all[start + j]

            scores = ys_c[:, j]
            labels = yt_c[:, j]

            sorted_scores, order = _stable_sort(scores, dim=0, descending=False)
            y_sorted = labels.gather(0, order).to(torch.float32)

            if tie_mode == "approx":
                ranks = ordinal_ranks
            else:
                # Average ranks for ties: every member of a run of equal scores
                # gets the mean of the run's first and last 1-based rank.
                change = torch.empty((N,), device=device, dtype=torch.bool)
                change[0] = True
                if N > 1:
                    change[1:] = sorted_scores[1:] != sorted_scores[:-1]
                group_starts = change.nonzero(as_tuple=False).squeeze(1)
                group_ends = torch.cat([group_starts[1:], n_end]) - 1

                lengths = (group_ends - group_starts + 1).to(torch.long)
                start_rank = (group_starts + 1).to(torch.float32)
                end_rank = (group_ends + 1).to(torch.float32)
                avg_rank = 0.5 * (start_rank + end_rank)
                ranks = avg_rank.repeat_interleave(lengths)

            sum_ranks_pos = (ranks * y_sorted).sum()
            u = sum_ranks_pos - (n_pos * (n_pos + 1.0) / 2.0)
            denom = n_pos * n_neg
            auc = u / torch.clamp(denom, min=1.0)

            # AUROC is undefined without both classes; keep NaN then.
            valid = (n_pos > 0.0) & (n_neg > 0.0)
            out[start + j] = torch.where(valid, auc.to(torch.float32), out[start + j])

    return out
|
|||
|
|
|
|||
|
|
|
|||
|
|
def precision_recall_at_k_percents_per_cause(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    k_percents: Sequence[float],
    *,
    chunk_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precision@K% and Recall@K% per cause.

    For each percentage ``kp`` the top ``k = ceil(kp * N / 100)`` samples
    (capped at N) by descending score are selected. Uses stable sort
    (descending) to match deterministic tie behavior.

    Args:
        y_true: (N,K) bool/int label tensor (nonzero = positive).
        y_score: (N,K) float score tensor.
        k_percents: P percentages; a value that yields k <= 0 leaves its row NaN.
        chunk_size: number of causes sliced per pass; must be >= 1.

    Returns:
        precision: (P,K) float32 -- positives among the top k, divided by k.
        recall: (P,K) float32 -- positives among the top k, divided by the
            cause's positive count; NaN for causes without positives.
        Both are all-NaN when N == 0.

    Raises:
        ValueError: on non-2-D / mismatched inputs or ``chunk_size < 1``.
    """
    _validate_binary_inputs(y_true, y_score)
    chunk_size = int(chunk_size)
    if chunk_size < 1:
        # A non-positive step would make the chunk loop empty -> silent all-NaN.
        raise ValueError("chunk_size must be >= 1")

    device = y_true.device
    N, K = y_true.shape
    P = int(len(k_percents))

    precision = torch.full((P, K), float("nan"), device=device, dtype=torch.float32)
    recall = torch.full((P, K), float("nan"), device=device, dtype=torch.float32)

    if N == 0:
        return precision, recall

    # (output row, k) for every cut-off that selects at least one sample.
    active: List[Tuple[int, int]] = []
    for p_idx, kp in enumerate(k_percents):
        # Multiply before dividing: (kp / 100) * N can land one ulp above an
        # integer (0.07 * 100 == 7.000000000000001) and ceil() would then pick
        # one sample too many. k is capped at N because at most N samples can
        # be selected, so precision is TP over the samples actually selected.
        k = min(int(math.ceil(float(kp) * float(N) / 100.0)), N)
        if k > 0:
            active.append((p_idx, k))
    if not active:
        return precision, recall

    rows = torch.tensor([p for p, _ in active], device=device, dtype=torch.long)
    k_sel = torch.tensor([k for _, k in active], device=device, dtype=torch.long)
    k_index = k_sel - 1  # 0-based position of the k-th ranked sample
    k_float = k_sel.to(torch.float32)
    nan_fill = torch.full((len(active),), float("nan"), device=device, dtype=torch.float32)

    yt = y_true.to(torch.bool)
    ys = y_score.to(torch.float32)
    n_pos_all = yt.sum(dim=0).to(torch.float32)

    for start in range(0, K, chunk_size):
        end = min(K, start + chunk_size)
        yt_c = yt[:, start:end]
        ys_c = ys[:, start:end]

        for j in range(end - start):
            col = start + j
            n_pos = n_pos_all[col]

            # stable descending order
            _, order = _stable_sort(ys_c[:, j], dim=0, descending=True)
            tp = yt_c[:, j].gather(0, order).to(torch.float32).cumsum(dim=0)
            # True positives within the top-k, for all cut-offs at once.
            tp_k = tp.index_select(0, k_index)

            precision[rows, col] = tp_k / k_float
            recall[rows, col] = torch.where(
                n_pos > 0.0,
                tp_k / n_pos.clamp(min=1.0),
                nan_fill,
            )

    return precision, recall
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class BinaryMetricsResult:
    """Per-cause results returned by ``compute_binary_metrics_torch``.

    K is the number of causes and P the number of requested K% cut-offs.
    """

    # AUROC per cause, shape (K,).
    auc_per_cause: torch.Tensor
    # Average precision per cause, shape (K,).
    ap_per_cause: torch.Tensor
    # Brier score per cause, shape (K,).
    brier_per_cause: torch.Tensor
    # Precision at each K% cut-off, shape (P, K).
    precision_at_k: torch.Tensor
    # Recall at each K% cut-off, shape (P, K).
    recall_at_k: torch.Tensor
    # Number of positive labels per cause, shape (K,).
    n_pos_per_cause: torch.Tensor
    # Number of negative labels per cause, shape (K,).
    n_neg_per_cause: torch.Tensor
    # ICI per cause, shape (K,); None when ICI was not requested.
    ici_per_cause: Optional[torch.Tensor] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@torch.inference_mode()
def compute_binary_metrics_torch(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    *,
    device: str | torch.device | None = None,
    k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0),
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
    compute_ici: bool = False,
    ici_bins: int = 15,
) -> BinaryMetricsResult:
    """Compute per-cause binary ranking metrics with torch (GPU-capable).

    Both inputs must be (N,K). When ``device`` is given, both tensors are
    moved there first; otherwise they are used where they already live (and
    presumably must share a device).

    Performance notes:
      - Computation is chunked over causes to bound peak memory.
      - For ``tie_mode="exact"``, AP and AUROC group tied scores, which is more
        correct under ties but uses per-cause loops (still on-device).

    Determinism:
      - Uses stable sorts where available.
      - Avoids nondeterministic selection ops for ties (no ``topk``).

    Returns:
        BinaryMetricsResult; ``ici_per_cause`` is None unless ``compute_ici``.

    Raises:
        ValueError: if the inputs are not 2-D tensors of equal shape.
    """
    _validate_binary_inputs(y_true, y_pred)

    if device is not None:
        target_device = torch.device(device)
        y_true = y_true.to(target_device)
        y_pred = y_pred.to(target_device)

    labels = y_true.to(torch.bool)
    scores = y_pred.to(torch.float32)
    n_samples = int(labels.shape[0])

    # Integer label counts per cause.
    positives = labels.sum(dim=0).to(torch.long)
    negatives = (n_samples - positives).to(torch.long)

    auc = auroc_per_cause(labels, scores, tie_mode=tie_mode, chunk_size=chunk_size)
    ap = average_precision_per_cause(labels, scores, tie_mode=tie_mode, chunk_size=chunk_size)
    brier = brier_per_cause(labels, scores)
    prec_k, rec_k = precision_recall_at_k_percents_per_cause(
        labels, scores, k_percents, chunk_size=chunk_size
    )

    ici: Optional[torch.Tensor] = None
    if compute_ici:
        ici = ici_per_cause_fixed_width(
            labels, scores, n_bins=int(ici_bins), chunk_size=chunk_size
        )

    return BinaryMetricsResult(
        auc_per_cause=auc,
        ap_per_cause=ap,
        brier_per_cause=brier,
        precision_at_k=prec_k,
        recall_at_k=rec_k,
        n_pos_per_cause=positives,
        n_neg_per_cause=negatives,
        ici_per_cause=ici,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@torch.inference_mode()
def compute_metrics_torch(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    *,
    device: str | torch.device | None = None,
    weights: Optional[torch.Tensor] = None,
    k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0),
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
    compute_ici: bool = False,
    ici_bins: int = 15,
) -> Dict[str, object]:
    """Convenience API: per-cause + macro/weighted aggregations.

    Returns a dict compatible with downstream reporting:
      - per-cause tensors under `per_cause`
      - macro + weighted summaries (NaN-aware)

    If `weights` is None, uses number of positives per cause as weights.
    Explicit `weights` must contain exactly one value per cause (K values);
    they are moved to the device the metrics were computed on. In the weighted
    means, causes with a non-finite metric or a non-finite / non-positive
    weight are skipped; macro means skip non-finite metrics.

    Raises:
        ValueError: if the inputs are not 2-D tensors of equal shape, or if
            `weights` does not contain exactly K values.
    """
    res = compute_binary_metrics_torch(
        y_true,
        y_pred,
        device=device,
        k_percents=k_percents,
        tie_mode=tie_mode,
        chunk_size=chunk_size,
        compute_ici=compute_ici,
        ici_bins=ici_bins,
    )

    metrics_device = res.auc_per_cause.device
    n_causes = int(res.auc_per_cause.shape[0])
    if weights is None:
        w = res.n_pos_per_cause.to(torch.float32)
    else:
        if weights.numel() != n_causes:
            raise ValueError(
                f"weights must have one entry per cause (K={n_causes}); "
                f"got shape {tuple(weights.shape)}"
            )
        # Caller-supplied weights may live on another device (e.g. CPU while
        # metrics are on GPU); mixing devices would fail in the weighted mean.
        w = weights.reshape(-1).to(device=metrics_device, dtype=torch.float32)

    # Built as a typed dict so optional entries can be added without indexing
    # a value typed `object`.
    per_cause: Dict[str, torch.Tensor] = {
        "auc": res.auc_per_cause,
        "ap": res.ap_per_cause,
        "brier": res.brier_per_cause,
        "precision_at_k": res.precision_at_k,
        "recall_at_k": res.recall_at_k,
        "n_pos": res.n_pos_per_cause,
        "n_neg": res.n_neg_per_cause,
    }

    out: Dict[str, object] = {
        "auc_macro": _nanmean(res.auc_per_cause),
        "auc_weighted": _nanweighted_mean(res.auc_per_cause, w),
        "ap_macro": _nanmean(res.ap_per_cause),
        "ap_weighted": _nanweighted_mean(res.ap_per_cause, w),
        "brier_macro": _nanmean(res.brier_per_cause),
        "brier_weighted": _nanweighted_mean(res.brier_per_cause, w),
        "per_cause": per_cause,
    }

    if res.ici_per_cause is not None:
        out["ici_macro"] = _nanmean(res.ici_per_cause)
        out["ici_weighted"] = _nanweighted_mean(res.ici_per_cause, w)
        per_cause["ici"] = res.ici_per_cause

    return out
|