diff --git a/torch_metrics.py b/torch_metrics.py deleted file mode 100644 index 803b1d1..0000000 --- a/torch_metrics.py +++ /dev/null @@ -1,524 +0,0 @@ -from __future__ import annotations - -import math -from dataclasses import dataclass -from typing import Dict, Iterable, List, Literal, Optional, Sequence, Tuple - -import torch - - -TieMode = Literal["exact", "approx"] - - -def _stable_sort( - x: torch.Tensor, - *, - dim: int, - descending: bool, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Stable torch sort when available. - - Determinism notes: - - When `stable=True` is supported by the installed PyTorch, we request it. - - Otherwise we fall back to `torch.sort`. For identical inputs on the same - device/runtime, this is typically deterministic, but tie ordering is not - guaranteed to be stable across versions. - """ - try: - return torch.sort(x, dim=dim, descending=descending, stable=True) - except TypeError: - return torch.sort(x, dim=dim, descending=descending) - - -def _nanmean(x: torch.Tensor) -> torch.Tensor: - mask = torch.isfinite(x) - if not bool(mask.any()): - return torch.tensor(float("nan"), device=x.device, dtype=x.dtype) - return x[mask].mean() - - -def _nanweighted_mean(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: - x = x.to(torch.float32) - w = w.to(torch.float32) - mask = torch.isfinite(x) & torch.isfinite(w) & (w > 0) - if not bool(mask.any()): - return torch.tensor(float("nan"), device=x.device, dtype=torch.float32) - ww = w[mask] - return (x[mask] * ww).sum() / ww.sum() - - -def _validate_binary_inputs(y_true: torch.Tensor, y_score: torch.Tensor) -> None: - if y_true.ndim != 2 or y_score.ndim != 2: - raise ValueError( - f"Expected y_true and y_score to be 2D (N,K); got {tuple(y_true.shape)} and {tuple(y_score.shape)}" - ) - if tuple(y_true.shape) != tuple(y_score.shape): - raise ValueError( - f"Shape mismatch: y_true{tuple(y_true.shape)} vs y_score{tuple(y_score.shape)}" - ) - - -def brier_per_cause(y_true: torch.Tensor, y_score: torch.Tensor) -> torch.Tensor: - """Brier score per cause. - - Args: - y_true: (N,K) bool/int tensor - y_score: (N,K) float tensor - - Returns: - (K,) float32 tensor; NaN if N==0. - """ - _validate_binary_inputs(y_true, y_score) - if y_true.numel() == 0: - return torch.full((y_true.size(1),), float("nan"), device=y_true.device, dtype=torch.float32) - yt = y_true.to(torch.float32) - ys = y_score.to(torch.float32) - return ((ys - yt) ** 2).mean(dim=0) - - -def ici_per_cause_fixed_width( - y_true: torch.Tensor, - y_score: torch.Tensor, - *, - n_bins: int = 15, - chunk_size: int = 128, -) -> torch.Tensor: - """Integrated Calibration Index (ICI) via fixed-width bins on [0,1]. - - ICI per cause = E[ |p_bin - y_bin| ] where bin stats are computed over - fixed-width probability bins. - - This is deterministic and GPU-friendly (scatter_add based). - - Returns: - (K,) float32 tensor; NaN when N==0. - """ - _validate_binary_inputs(y_true, y_score) - if int(n_bins) <= 1: - raise ValueError("n_bins must be >= 2") - - device = y_true.device - N, K = y_true.shape - if N == 0: - return torch.full((K,), float("nan"), device=device, dtype=torch.float32) - - yt = y_true.to(torch.float32) - ys = y_score.to(torch.float32).clamp(0.0, 1.0) - - out = torch.full((K,), float("nan"), device=device, dtype=torch.float32) - - for start in range(0, K, int(chunk_size)): - end = min(K, start + int(chunk_size)) - ys_c = ys[:, start:end] - yt_c = yt[:, start:end] - - # bin index in [0, n_bins-1] - bin_idx = torch.clamp( - (ys_c * float(n_bins)).to(torch.long), max=int(n_bins) - 1) - - counts = torch.zeros((int(n_bins), end - start), - device=device, dtype=torch.float32) - pred_sums = torch.zeros_like(counts) - true_sums = torch.zeros_like(counts) - - ones = torch.ones_like(ys_c, dtype=torch.float32) - counts.scatter_add_(0, bin_idx, ones) - pred_sums.scatter_add_(0, bin_idx, ys_c) - true_sums.scatter_add_(0, bin_idx, yt_c) - - denom = counts.clamp(min=1.0) - pred_mean = pred_sums / denom - true_mean = true_sums / denom - - abs_gap = (pred_mean - true_mean).abs() - # sample-weighted average of bin gap - total = counts.sum(dim=0).clamp(min=1.0) - ici = (abs_gap * counts).sum(dim=0) / total - - # If a cause has no samples (shouldn't happen when N>0), mark NaN. - out[start:end] = torch.where( - total > 0, ici.to(torch.float32), out[start:end]) - - return out - - -def average_precision_per_cause( - y_true: torch.Tensor, - y_score: torch.Tensor, - *, - tie_mode: TieMode = "exact", - chunk_size: int = 128, -) -> torch.Tensor: - """Average precision (AP) per cause. - - Definition matches sklearn's `average_precision_score` (step-wise PR integral): - AP = \\sum_i (R_i - R_{i-1}) * P_i - where i iterates over unique score thresholds. - - tie_mode: - - "exact": tie-invariant AP by grouping identical scores (recommended) - - "approx": mean precision at positive ranks; can differ under ties - - Returns: - (K,) float32 tensor with NaN for causes with 0 positives. - """ - _validate_binary_inputs(y_true, y_score) - - device = y_true.device - N, K = y_true.shape - if N == 0: - return torch.full((K,), float("nan"), device=device, dtype=torch.float32) - - yt = y_true.to(torch.bool) - ys = y_score.to(torch.float32) - - n_pos_all = yt.sum(dim=0).to(torch.float32) - out = torch.full((K,), float("nan"), device=device, dtype=torch.float32) - - for start in range(0, K, int(chunk_size)): - end = min(K, start + int(chunk_size)) - yt_c = yt[:, start:end] - ys_c = ys[:, start:end] - - # For exact mode we need per-cause tie grouping; do per-cause loops - # within a chunk to keep memory bounded and stay on GPU. - for j in range(end - start): - n_pos = n_pos_all[start + j] - - scores = ys_c[:, j] - labels = yt_c[:, j] - - if tie_mode == "approx": - _, order = _stable_sort(scores, dim=0, descending=True) - y_sorted = labels.gather(0, order).to(torch.float32) - tp = y_sorted.cumsum(dim=0) - denom = torch.arange( - 1, N + 1, device=device, dtype=torch.float32) - precision = tp / denom - n_pos_safe = torch.clamp(n_pos, min=1.0) - ap = (precision * y_sorted).sum() / n_pos_safe - out[start + j] = torch.where(n_pos > 0.0, - ap.to(torch.float32), out[start + j]) - continue - - # exact: group by unique score thresholds - sorted_scores, order = _stable_sort(scores, dim=0, descending=True) - y_sorted = labels.gather(0, order).to(torch.float32) - - # group boundaries where score changes - change = torch.empty((N,), device=device, dtype=torch.bool) - change[0] = True - if N > 1: - change[1:] = sorted_scores[1:] != sorted_scores[:-1] - group_starts = change.nonzero(as_tuple=False).squeeze(1) - group_ends = torch.cat( - [group_starts[1:], torch.tensor( - [N], device=device, dtype=group_starts.dtype)] - ) - 1 - - tp = y_sorted.cumsum(dim=0) - fp = torch.arange(1, N + 1, device=device, dtype=torch.float32) - tp - - tp_end = tp[group_ends] - fp_end = fp[group_ends] - precision = tp_end / torch.clamp(tp_end + fp_end, min=1.0) - n_pos_safe = torch.clamp(n_pos, min=1.0) - recall = tp_end / n_pos_safe - - recall_prev = torch.cat( - [torch.zeros((1,), device=device, - dtype=torch.float32), recall[:-1]] - ) - ap = ((recall - recall_prev) * precision).sum() - out[start + j] = torch.where(n_pos > 0.0, - ap.to(torch.float32), out[start + j]) - - return out - - -def auroc_per_cause( - y_true: torch.Tensor, - y_score: torch.Tensor, - *, - tie_mode: TieMode = "exact", - chunk_size: int = 128, -) -> torch.Tensor: - """AUROC per cause via Mann–Whitney U. - - AUC = (sum_ranks_pos - n_pos*(n_pos+1)/2) / (n_pos*n_neg) - - tie_mode: - - "exact": average ranks for ties (matches typical sklearn tie behavior) - - "approx": ordinal ranks (faster, differs under ties) - - Returns: - (K,) float32 tensor; NaN when a cause has n_pos==0 or n_neg==0. - """ - _validate_binary_inputs(y_true, y_score) - - device = y_true.device - N, K = y_true.shape - if N == 0: - return torch.full((K,), float("nan"), device=device, dtype=torch.float32) - - yt = y_true.to(torch.bool) - ys = y_score.to(torch.float32) - - n_pos_all = yt.sum(dim=0).to(torch.float32) - n_neg_all = (float(N) - n_pos_all).to(torch.float32) - - out = torch.full((K,), float("nan"), device=device, dtype=torch.float32) - - for start in range(0, K, int(chunk_size)): - end = min(K, start + int(chunk_size)) - yt_c = yt[:, start:end] - ys_c = ys[:, start:end] - - for j in range(end - start): - n_pos = n_pos_all[start + j] - n_neg = n_neg_all[start + j] - - scores = ys_c[:, j] - labels = yt_c[:, j] - - sorted_scores, order = _stable_sort(scores, dim=0, descending=False) - y_sorted = labels.gather(0, order).to(torch.float32) - - if tie_mode == "approx": - ranks = torch.arange( - 1, N + 1, device=device, dtype=torch.float32) - else: - # average ranks for ties - change = torch.empty((N,), device=device, dtype=torch.bool) - change[0] = True - if N > 1: - change[1:] = sorted_scores[1:] != sorted_scores[:-1] - group_starts = change.nonzero(as_tuple=False).squeeze(1) - group_ends = torch.cat( - [group_starts[1:], torch.tensor( - [N], device=device, dtype=group_starts.dtype)] - ) - 1 - - lengths = (group_ends - group_starts + 1).to(torch.long) - start_rank = (group_starts + 1).to(torch.float32) - end_rank = (group_ends + 1).to(torch.float32) - avg_rank = 0.5 * (start_rank + end_rank) - ranks = avg_rank.repeat_interleave(lengths) - - sum_ranks_pos = (ranks * y_sorted).sum() - u = sum_ranks_pos - (n_pos * (n_pos + 1.0) / 2.0) - denom = n_pos * n_neg - auc = u / torch.clamp(denom, min=1.0) - valid = (n_pos > 0.0) & (n_neg > 0.0) - out[start + j] = torch.where(valid, - auc.to(torch.float32), out[start + j]) - - return out - - -def precision_recall_at_k_percents_per_cause( - y_true: torch.Tensor, - y_score: torch.Tensor, - k_percents: Sequence[float], - *, - chunk_size: int = 128, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Precision@K% and Recall@K% per cause. - - Uses stable sort (descending) to match deterministic tie behavior. - - Returns: - precision: (P,K) float32 - recall: (P,K) float32 - """ - _validate_binary_inputs(y_true, y_score) - - device = y_true.device - N, K = y_true.shape - P = int(len(k_percents)) - - precision = torch.full((P, K), float( - "nan"), device=device, dtype=torch.float32) - recall = torch.full((P, K), float( - "nan"), device=device, dtype=torch.float32) - - if N == 0: - return precision, recall - - yt = y_true.to(torch.bool) - ys = y_score.to(torch.float32) - - n_pos_all = yt.sum(dim=0).to(torch.float32) - - ks: List[int] = [] - for kp in k_percents: - k = int(math.ceil((float(kp) / 100.0) * float(N))) - ks.append(k) - - for start in range(0, K, int(chunk_size)): - end = min(K, start + int(chunk_size)) - yt_c = yt[:, start:end] - ys_c = ys[:, start:end] - - for j in range(end - start): - scores = ys_c[:, j] - labels = yt_c[:, j] - n_pos = n_pos_all[start + j] - - # stable descending order - _, order = _stable_sort(scores, dim=0, descending=True) - y_sorted = labels.gather(0, order).to(torch.float32) - tp = y_sorted.cumsum(dim=0) - - for p_idx, k in enumerate(ks): - if k <= 0: - continue - tp_k = tp[min(k, N) - 1] - precision[p_idx, start + j] = tp_k / float(k) - recall[p_idx, start + j] = torch.where( - n_pos > 0.0, - tp_k / n_pos, - torch.tensor(float("nan"), device=device, - dtype=torch.float32), - ) - - return precision, recall - - -@dataclass -class BinaryMetricsResult: - auc_per_cause: torch.Tensor # (K,) - ap_per_cause: torch.Tensor # (K,) - brier_per_cause: torch.Tensor # (K,) - precision_at_k: torch.Tensor # (P,K) - recall_at_k: torch.Tensor # (P,K) - n_pos_per_cause: torch.Tensor # (K,) - n_neg_per_cause: torch.Tensor # (K,) - ici_per_cause: Optional[torch.Tensor] = None # (K,) - - -@torch.inference_mode() -def compute_binary_metrics_torch( - y_true: torch.Tensor, - y_pred: torch.Tensor, - *, - device: str | torch.device | None = None, - k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0), - tie_mode: TieMode = "exact", - chunk_size: int = 128, - compute_ici: bool = False, - ici_bins: int = 15, -) -> BinaryMetricsResult: - """Compute per-cause binary ranking metrics on GPU using torch. - - Inputs must be (N,K) and live on the device you want to compute on. - - Performance notes: - - Computation is chunked over causes to bound peak memory. - - For `tie_mode="exact"`, AP and AUROC are computed with tie grouping, which - is more correct under ties but uses per-cause loops (still GPU-resident). - - Determinism: - - Uses stable sorts where available. - - Avoids nondeterministic selection ops for ties (no `topk`). - """ - _validate_binary_inputs(y_true, y_pred) - - if device is not None: - device = torch.device(device) - y_true = y_true.to(device) - y_pred = y_pred.to(device) - - N, K = y_true.shape - - yt = y_true.to(torch.bool) - yp = y_pred.to(torch.float32) - - n_pos = yt.sum(dim=0).to(torch.long) - n_neg = (int(N) - n_pos).to(torch.long) - - auc = auroc_per_cause(yt, yp, tie_mode=tie_mode, chunk_size=chunk_size) - ap = average_precision_per_cause( - yt, yp, tie_mode=tie_mode, chunk_size=chunk_size) - brier = brier_per_cause(yt, yp) - - prec_k, rec_k = precision_recall_at_k_percents_per_cause( - yt, yp, k_percents, chunk_size=chunk_size - ) - - ici = None - if compute_ici: - ici = ici_per_cause_fixed_width( - yt, yp, n_bins=int(ici_bins), chunk_size=chunk_size) - - return BinaryMetricsResult( - auc_per_cause=auc, - ap_per_cause=ap, - brier_per_cause=brier, - precision_at_k=prec_k, - recall_at_k=rec_k, - n_pos_per_cause=n_pos, - n_neg_per_cause=n_neg, - ici_per_cause=ici, - ) - - -@torch.inference_mode() -def compute_metrics_torch( - y_true: torch.Tensor, - y_pred: torch.Tensor, - *, - device: str | torch.device | None = None, - weights: Optional[torch.Tensor] = None, - k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0), - tie_mode: TieMode = "exact", - chunk_size: int = 128, - compute_ici: bool = False, - ici_bins: int = 15, -) -> Dict[str, object]: - """Convenience API: per-cause + macro/weighted aggregations. - - Returns a dict compatible with downstream reporting: - - per-cause tensors under `per_cause` - - macro + weighted summaries (NaN-aware) - - If `weights` is None, uses number of positives per cause as weights. - """ - res = compute_binary_metrics_torch( - y_true, - y_pred, - device=device, - k_percents=k_percents, - tie_mode=tie_mode, - chunk_size=chunk_size, - compute_ici=compute_ici, - ici_bins=ici_bins, - ) - - w = res.n_pos_per_cause.to( - torch.float32) if weights is None else weights.to(torch.float32) - - out: Dict[str, object] = { - "auc_macro": _nanmean(res.auc_per_cause), - "auc_weighted": _nanweighted_mean(res.auc_per_cause, w), - "ap_macro": _nanmean(res.ap_per_cause), - "ap_weighted": _nanweighted_mean(res.ap_per_cause, w), - "brier_macro": _nanmean(res.brier_per_cause), - "brier_weighted": _nanweighted_mean(res.brier_per_cause, w), - "per_cause": { - "auc": res.auc_per_cause, - "ap": res.ap_per_cause, - "brier": res.brier_per_cause, - "precision_at_k": res.precision_at_k, - "recall_at_k": res.recall_at_k, - "n_pos": res.n_pos_per_cause, - "n_neg": res.n_neg_per_cause, - }, - } - - if res.ici_per_cause is not None: - out["ici_macro"] = _nanmean(res.ici_per_cause) - out["ici_weighted"] = _nanweighted_mean(res.ici_per_cause, w) - out["per_cause"]["ici"] = res.ici_per_cause - - return out