# DeepHealth/torch_metrics.py

from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Sequence, Tuple

import torch

TieMode = Literal["exact", "approx"]


def _stable_sort(
    x: torch.Tensor,
    *,
    dim: int,
    descending: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stable torch sort when available.

    Determinism notes:
    - When `stable=True` is supported by the installed PyTorch, we request it.
    - Otherwise we fall back to `torch.sort`. For identical inputs on the same
      device/runtime, this is typically deterministic, but tie ordering is not
      guaranteed to be stable across versions.
    """
    try:
        return torch.sort(x, dim=dim, descending=descending, stable=True)
    except TypeError:
        return torch.sort(x, dim=dim, descending=descending)


def _nanmean(x: torch.Tensor) -> torch.Tensor:
    """Mean over finite entries; NaN when no entry is finite."""
    mask = torch.isfinite(x)
    if not bool(mask.any()):
        return torch.tensor(float("nan"), device=x.device, dtype=x.dtype)
    return x[mask].mean()


def _nanweighted_mean(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Weighted mean over entries where value and weight are finite and w > 0."""
    x = x.to(torch.float32)
    w = w.to(torch.float32)
    mask = torch.isfinite(x) & torch.isfinite(w) & (w > 0)
    if not bool(mask.any()):
        return torch.tensor(float("nan"), device=x.device, dtype=torch.float32)
    ww = w[mask]
    return (x[mask] * ww).sum() / ww.sum()


def _validate_binary_inputs(y_true: torch.Tensor, y_score: torch.Tensor) -> None:
    if y_true.ndim != 2 or y_score.ndim != 2:
        raise ValueError(
            f"Expected y_true and y_score to be 2D (N,K); got {tuple(y_true.shape)} and {tuple(y_score.shape)}"
        )
    if tuple(y_true.shape) != tuple(y_score.shape):
        raise ValueError(
            f"Shape mismatch: y_true{tuple(y_true.shape)} vs y_score{tuple(y_score.shape)}"
        )


def brier_per_cause(y_true: torch.Tensor, y_score: torch.Tensor) -> torch.Tensor:
    """Brier score per cause.

    Args:
        y_true: (N,K) bool/int tensor
        y_score: (N,K) float tensor

    Returns:
        (K,) float32 tensor; NaN if N==0.
    """
    _validate_binary_inputs(y_true, y_score)
    if y_true.numel() == 0:
        return torch.full((y_true.size(1),), float("nan"), device=y_true.device, dtype=torch.float32)
    yt = y_true.to(torch.float32)
    ys = y_score.to(torch.float32)
    return ((ys - yt) ** 2).mean(dim=0)


def ici_per_cause_fixed_width(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    *,
    n_bins: int = 15,
    chunk_size: int = 128,
) -> torch.Tensor:
    """Integrated Calibration Index (ICI) via fixed-width bins on [0,1].

    ICI per cause = E[ |p_bin - y_bin| ], where bin statistics are computed
    over fixed-width probability bins.

    This is deterministic and GPU-friendly (scatter_add based).

    Returns:
        (K,) float32 tensor; NaN when N==0.
    """
    _validate_binary_inputs(y_true, y_score)
    if int(n_bins) <= 1:
        raise ValueError("n_bins must be >= 2")
    device = y_true.device
    N, K = y_true.shape
    if N == 0:
        return torch.full((K,), float("nan"), device=device, dtype=torch.float32)
    yt = y_true.to(torch.float32)
    ys = y_score.to(torch.float32).clamp(0.0, 1.0)
    out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)
    for start in range(0, K, int(chunk_size)):
        end = min(K, start + int(chunk_size))
        ys_c = ys[:, start:end]
        yt_c = yt[:, start:end]
        # Bin index in [0, n_bins-1]; the clamp catches scores exactly 1.0.
        bin_idx = torch.clamp((ys_c * float(n_bins)).to(torch.long), max=int(n_bins) - 1)
        counts = torch.zeros((int(n_bins), end - start), device=device, dtype=torch.float32)
        pred_sums = torch.zeros_like(counts)
        true_sums = torch.zeros_like(counts)
        ones = torch.ones_like(ys_c, dtype=torch.float32)
        counts.scatter_add_(0, bin_idx, ones)
        pred_sums.scatter_add_(0, bin_idx, ys_c)
        true_sums.scatter_add_(0, bin_idx, yt_c)
        denom = counts.clamp(min=1.0)
        pred_mean = pred_sums / denom
        true_mean = true_sums / denom
        abs_gap = (pred_mean - true_mean).abs()
        # Sample-weighted average of the per-bin gap.
        total = counts.sum(dim=0).clamp(min=1.0)
        ici = (abs_gap * counts).sum(dim=0) / total
        # If a cause has no samples (cannot happen when N>0), keep NaN.
        out[start:end] = torch.where(total > 0, ici.to(torch.float32), out[start:end])
    return out
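
# Illustrative check of the binning arithmetic (toy values, not from any real
# run): with n_bins=2 and a single cause, scores [0.1, 0.2, 0.9] with labels
# [0, 1, 1] fall into bins [0, 0, 1]; bin means are pred=(0.15, 0.9) and
# true=(0.5, 1.0), the gaps are (0.35, 0.10), and the count-weighted ICI is
# (2*0.35 + 1*0.10) / 3 = 0.8/3 ≈ 0.2667.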


def average_precision_per_cause(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    *,
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
) -> torch.Tensor:
    """Average precision (AP) per cause.

    Definition matches sklearn's `average_precision_score` (step-wise PR
    integral):

        AP = \\sum_i (R_i - R_{i-1}) * P_i

    where i iterates over unique score thresholds.

    tie_mode:
    - "exact": tie-invariant AP by grouping identical scores (recommended)
    - "approx": mean precision at positive ranks; can differ under ties

    Returns:
        (K,) float32 tensor with NaN for causes with 0 positives.
    """
    _validate_binary_inputs(y_true, y_score)
    device = y_true.device
    N, K = y_true.shape
    if N == 0:
        return torch.full((K,), float("nan"), device=device, dtype=torch.float32)
    yt = y_true.to(torch.bool)
    ys = y_score.to(torch.float32)
    n_pos_all = yt.sum(dim=0).to(torch.float32)
    out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)
    for start in range(0, K, int(chunk_size)):
        end = min(K, start + int(chunk_size))
        yt_c = yt[:, start:end]
        ys_c = ys[:, start:end]
        # Exact mode needs per-cause tie grouping; loop over causes within a
        # chunk to keep memory bounded while staying on the GPU.
        for j in range(end - start):
            n_pos = n_pos_all[start + j]
            scores = ys_c[:, j]
            labels = yt_c[:, j]
            if tie_mode == "approx":
                _, order = _stable_sort(scores, dim=0, descending=True)
                y_sorted = labels.gather(0, order).to(torch.float32)
                tp = y_sorted.cumsum(dim=0)
                denom = torch.arange(1, N + 1, device=device, dtype=torch.float32)
                precision = tp / denom
                n_pos_safe = torch.clamp(n_pos, min=1.0)
                ap = (precision * y_sorted).sum() / n_pos_safe
                out[start + j] = torch.where(n_pos > 0.0, ap.to(torch.float32), out[start + j])
                continue
            # Exact: group by unique score thresholds.
            sorted_scores, order = _stable_sort(scores, dim=0, descending=True)
            y_sorted = labels.gather(0, order).to(torch.float32)
            # Group boundaries where the score changes.
            change = torch.empty((N,), device=device, dtype=torch.bool)
            change[0] = True
            if N > 1:
                change[1:] = sorted_scores[1:] != sorted_scores[:-1]
            group_starts = change.nonzero(as_tuple=False).squeeze(1)
            group_ends = torch.cat(
                [group_starts[1:], torch.tensor([N], device=device, dtype=group_starts.dtype)]
            ) - 1
            tp = y_sorted.cumsum(dim=0)
            fp = torch.arange(1, N + 1, device=device, dtype=torch.float32) - tp
            tp_end = tp[group_ends]
            fp_end = fp[group_ends]
            precision = tp_end / torch.clamp(tp_end + fp_end, min=1.0)
            n_pos_safe = torch.clamp(n_pos, min=1.0)
            recall = tp_end / n_pos_safe
            recall_prev = torch.cat(
                [torch.zeros((1,), device=device, dtype=torch.float32), recall[:-1]]
            )
            ap = ((recall - recall_prev) * precision).sum()
            out[start + j] = torch.where(n_pos > 0.0, ap.to(torch.float32), out[start + j])
    return out
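
# Illustrative check (toy values): scores [0.9, 0.8, 0.7] with labels
# [1, 0, 1] give (precision, recall) steps of (1.0, 0.5), (0.5, 0.5), and
# (2/3, 1.0) at the three thresholds, so
# AP = 0.5*1.0 + 0.0*0.5 + 0.5*(2/3) = 5/6 ≈ 0.8333.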


def auroc_per_cause(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    *,
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
) -> torch.Tensor:
    """AUROC per cause via the Mann-Whitney U statistic.

        AUC = (sum_ranks_pos - n_pos*(n_pos+1)/2) / (n_pos*n_neg)

    tie_mode:
    - "exact": average ranks for ties (matches typical sklearn tie behavior)
    - "approx": ordinal ranks (faster, differs under ties)

    Returns:
        (K,) float32 tensor; NaN when a cause has n_pos==0 or n_neg==0.
    """
    _validate_binary_inputs(y_true, y_score)
    device = y_true.device
    N, K = y_true.shape
    if N == 0:
        return torch.full((K,), float("nan"), device=device, dtype=torch.float32)
    yt = y_true.to(torch.bool)
    ys = y_score.to(torch.float32)
    n_pos_all = yt.sum(dim=0).to(torch.float32)
    n_neg_all = float(N) - n_pos_all
    out = torch.full((K,), float("nan"), device=device, dtype=torch.float32)
    for start in range(0, K, int(chunk_size)):
        end = min(K, start + int(chunk_size))
        yt_c = yt[:, start:end]
        ys_c = ys[:, start:end]
        for j in range(end - start):
            n_pos = n_pos_all[start + j]
            n_neg = n_neg_all[start + j]
            scores = ys_c[:, j]
            labels = yt_c[:, j]
            sorted_scores, order = _stable_sort(scores, dim=0, descending=False)
            y_sorted = labels.gather(0, order).to(torch.float32)
            if tie_mode == "approx":
                ranks = torch.arange(1, N + 1, device=device, dtype=torch.float32)
            else:
                # Average ranks for ties.
                change = torch.empty((N,), device=device, dtype=torch.bool)
                change[0] = True
                if N > 1:
                    change[1:] = sorted_scores[1:] != sorted_scores[:-1]
                group_starts = change.nonzero(as_tuple=False).squeeze(1)
                group_ends = torch.cat(
                    [group_starts[1:], torch.tensor([N], device=device, dtype=group_starts.dtype)]
                ) - 1
                lengths = (group_ends - group_starts + 1).to(torch.long)
                start_rank = (group_starts + 1).to(torch.float32)
                end_rank = (group_ends + 1).to(torch.float32)
                avg_rank = 0.5 * (start_rank + end_rank)
                ranks = avg_rank.repeat_interleave(lengths)
            sum_ranks_pos = (ranks * y_sorted).sum()
            u = sum_ranks_pos - (n_pos * (n_pos + 1.0) / 2.0)
            denom = n_pos * n_neg
            auc = u / torch.clamp(denom, min=1.0)
            valid = (n_pos > 0.0) & (n_neg > 0.0)
            out[start + j] = torch.where(valid, auc.to(torch.float32), out[start + j])
    return out
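
# Illustrative check (toy values): scores [0.2, 0.4, 0.6, 0.8] with labels
# [0, 0, 1, 1] sort ascending to ranks (1, 2, 3, 4); the positive ranks sum
# to 7, so U = 7 - 2*3/2 = 4 and AUC = 4 / (2*2) = 1.0.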


def precision_recall_at_k_percents_per_cause(
    y_true: torch.Tensor,
    y_score: torch.Tensor,
    k_percents: Sequence[float],
    *,
    chunk_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precision@K% and Recall@K% per cause.

    Uses stable sort (descending) to match deterministic tie behavior.

    Returns:
        precision: (P,K) float32
        recall: (P,K) float32
    """
    _validate_binary_inputs(y_true, y_score)
    device = y_true.device
    N, K = y_true.shape
    P = int(len(k_percents))
    precision = torch.full((P, K), float("nan"), device=device, dtype=torch.float32)
    recall = torch.full((P, K), float("nan"), device=device, dtype=torch.float32)
    if N == 0:
        return precision, recall
    yt = y_true.to(torch.bool)
    ys = y_score.to(torch.float32)
    n_pos_all = yt.sum(dim=0).to(torch.float32)
    # Convert each K% to an absolute cutoff k = ceil(K% of N).
    ks: List[int] = [int(math.ceil((float(kp) / 100.0) * float(N))) for kp in k_percents]
    for start in range(0, K, int(chunk_size)):
        end = min(K, start + int(chunk_size))
        yt_c = yt[:, start:end]
        ys_c = ys[:, start:end]
        for j in range(end - start):
            scores = ys_c[:, j]
            labels = yt_c[:, j]
            n_pos = n_pos_all[start + j]
            # Stable descending order.
            _, order = _stable_sort(scores, dim=0, descending=True)
            y_sorted = labels.gather(0, order).to(torch.float32)
            tp = y_sorted.cumsum(dim=0)
            for p_idx, k in enumerate(ks):
                if k <= 0:
                    continue
                tp_k = tp[min(k, N) - 1]
                precision[p_idx, start + j] = tp_k / float(k)
                recall[p_idx, start + j] = torch.where(
                    n_pos > 0.0,
                    tp_k / n_pos,
                    torch.tensor(float("nan"), device=device, dtype=torch.float32),
                )
    return precision, recall
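
# Illustrative check (toy values): with N=10 samples, 3 of them positive, and
# K%=20, the cutoff is k = ceil(0.2 * 10) = 2; if both top-2 scores are
# positives, precision@20% = 2/2 = 1.0 and recall@20% = 2/3.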


@dataclass
class BinaryMetricsResult:
    auc_per_cause: torch.Tensor    # (K,)
    ap_per_cause: torch.Tensor     # (K,)
    brier_per_cause: torch.Tensor  # (K,)
    precision_at_k: torch.Tensor   # (P,K)
    recall_at_k: torch.Tensor      # (P,K)
    n_pos_per_cause: torch.Tensor  # (K,)
    n_neg_per_cause: torch.Tensor  # (K,)
    ici_per_cause: Optional[torch.Tensor] = None  # (K,)


@torch.inference_mode()
def compute_binary_metrics_torch(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    *,
    device: str | torch.device | None = None,
    k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0),
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
    compute_ici: bool = False,
    ici_bins: int = 15,
) -> BinaryMetricsResult:
    """Compute per-cause binary ranking metrics on GPU using torch.

    Inputs must be (N,K) and live on the device you want to compute on.

    Performance notes:
    - Computation is chunked over causes to bound peak memory.
    - For `tie_mode="exact"`, AP and AUROC are computed with tie grouping,
      which is more correct under ties but uses per-cause loops (still
      GPU-resident).

    Determinism:
    - Uses stable sorts where available.
    - Avoids nondeterministic selection ops for ties (no `topk`).
    """
    _validate_binary_inputs(y_true, y_pred)
    if device is not None:
        device = torch.device(device)
        y_true = y_true.to(device)
        y_pred = y_pred.to(device)
    N, K = y_true.shape
    yt = y_true.to(torch.bool)
    yp = y_pred.to(torch.float32)
    n_pos = yt.sum(dim=0).to(torch.long)
    n_neg = int(N) - n_pos
    auc = auroc_per_cause(yt, yp, tie_mode=tie_mode, chunk_size=chunk_size)
    ap = average_precision_per_cause(yt, yp, tie_mode=tie_mode, chunk_size=chunk_size)
    brier = brier_per_cause(yt, yp)
    prec_k, rec_k = precision_recall_at_k_percents_per_cause(
        yt, yp, k_percents, chunk_size=chunk_size
    )
    ici = None
    if compute_ici:
        ici = ici_per_cause_fixed_width(yt, yp, n_bins=int(ici_bins), chunk_size=chunk_size)
    return BinaryMetricsResult(
        auc_per_cause=auc,
        ap_per_cause=ap,
        brier_per_cause=brier,
        precision_at_k=prec_k,
        recall_at_k=rec_k,
        n_pos_per_cause=n_pos,
        n_neg_per_cause=n_neg,
        ici_per_cause=ici,
    )


@torch.inference_mode()
def compute_metrics_torch(
    y_true: torch.Tensor,
    y_pred: torch.Tensor,
    *,
    device: str | torch.device | None = None,
    weights: Optional[torch.Tensor] = None,
    k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0),
    tie_mode: TieMode = "exact",
    chunk_size: int = 128,
    compute_ici: bool = False,
    ici_bins: int = 15,
) -> Dict[str, object]:
    """Convenience API: per-cause + macro/weighted aggregations.

    Returns a dict compatible with downstream reporting:
    - per-cause tensors under `per_cause`
    - macro + weighted summaries (NaN-aware)

    If `weights` is None, uses the number of positives per cause as weights.
    """
    res = compute_binary_metrics_torch(
        y_true,
        y_pred,
        device=device,
        k_percents=k_percents,
        tie_mode=tie_mode,
        chunk_size=chunk_size,
        compute_ici=compute_ici,
        ici_bins=ici_bins,
    )
    w = res.n_pos_per_cause.to(torch.float32) if weights is None else weights.to(torch.float32)
    out: Dict[str, object] = {
        "auc_macro": _nanmean(res.auc_per_cause),
        "auc_weighted": _nanweighted_mean(res.auc_per_cause, w),
        "ap_macro": _nanmean(res.ap_per_cause),
        "ap_weighted": _nanweighted_mean(res.ap_per_cause, w),
        "brier_macro": _nanmean(res.brier_per_cause),
        "brier_weighted": _nanweighted_mean(res.brier_per_cause, w),
        "per_cause": {
            "auc": res.auc_per_cause,
            "ap": res.ap_per_cause,
            "brier": res.brier_per_cause,
            "precision_at_k": res.precision_at_k,
            "recall_at_k": res.recall_at_k,
            "n_pos": res.n_pos_per_cause,
            "n_neg": res.n_neg_per_cause,
        },
    }
    if res.ici_per_cause is not None:
        out["ici_macro"] = _nanmean(res.ici_per_cause)
        out["ici_weighted"] = _nanweighted_mean(res.ici_per_cause, w)
        out["per_cause"]["ici"] = res.ici_per_cause
    return out
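

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the shapes, seed, and ~10% positive
    # rate below are illustrative assumptions, not part of any DeepHealth
    # pipeline contract.
    torch.manual_seed(0)
    N, K = 1000, 8
    y_true = torch.rand(N, K) < 0.1   # (N,K) bool labels
    y_pred = torch.rand(N, K)         # (N,K) float scores in [0,1]
    metrics = compute_metrics_torch(y_true, y_pred, compute_ici=True)
    print("macro AUC:", float(metrics["auc_macro"]))
    print("macro AP:", float(metrics["ap_macro"]))
    print("macro ICI:", float(metrics["ici_macro"]))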