diff --git a/evaluation_age_time_dependent.py b/evaluation_age_time_dependent.py index f788b2a..a4cf12a 100644 --- a/evaluation_age_time_dependent.py +++ b/evaluation_age_time_dependent.py @@ -20,6 +20,8 @@ from utils import ( sample_context_in_fixed_age_bin, ) +from torch_metrics import compute_binary_metrics_torch + def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: """Aggregate per-bin age evaluation results. @@ -173,105 +175,8 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: return df_agg -def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float: - """Compute ROC AUC for binary labels with tie-aware ranking. - - Returns NaN if y_true has no positives or no negatives. - - Uses the Mann–Whitney U statistic with average ranks for ties. - """ - y_true = np.asarray(y_true).astype(bool) - y_score = np.asarray(y_score).astype(float) - - n = y_true.size - if n == 0: - return float("nan") - - n_pos = int(y_true.sum()) - n_neg = n - n_pos - if n_pos == 0 or n_neg == 0: - return float("nan") - - # Rank scores ascending, average ranks for ties. - order = np.argsort(y_score, kind="mergesort") - sorted_scores = y_score[order] - - ranks = np.empty(n, dtype=float) - i = 0 - # 1-based ranks - while i < n: - j = i + 1 - while j < n and sorted_scores[j] == sorted_scores[i]: - j += 1 - avg_rank = 0.5 * ((i + 1) + j) # ranks i+1 .. j - ranks[order[i:j]] = avg_rank - i = j - - sum_ranks_pos = float(ranks[y_true].sum()) - u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0) - return float(u / (n_pos * n_neg)) - - -def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float: - """Average precision (area under PR curve using step-wise interpolation). - - Returns NaN if no positives. - """ - y_true = np.asarray(y_true).astype(bool) - y_score = np.asarray(y_score).astype(float) - - n = y_true.size - if n == 0: - return float("nan") - - n_pos = int(y_true.sum()) - if n_pos == 0: - return float("nan") - - order = np.argsort(-y_score, kind="mergesort") - y = y_true[order] - - tp = np.cumsum(y).astype(float) - fp = np.cumsum(~y).astype(float) - precision = tp / np.maximum(tp + fp, 1.0) - - # AP = sum over each positive of precision at that point / n_pos - # (equivalent to ∑ Δrecall * precision) - ap = float(np.sum(precision[y]) / n_pos) - # handle potential tiny numerical overshoots - return float(max(0.0, min(1.0, ap))) - - -def _precision_recall_at_k_percent( - y_true: np.ndarray, - y_score: np.ndarray, - k_percent: float, -) -> Tuple[float, float]: - """Precision@K% and Recall@K% for binary labels. - - Returns (precision, recall). Returns NaN for recall if no positives. - Returns NaN for precision if k leads to 0 selected. - """ - y_true = np.asarray(y_true).astype(bool) - y_score = np.asarray(y_score).astype(float) - - n = y_true.size - if n == 0: - return float("nan"), float("nan") - - n_pos = int(y_true.sum()) - - k = int(math.ceil((float(k_percent) / 100.0) * n)) - if k <= 0: - return float("nan"), float("nan") - - order = np.argsort(-y_score, kind="mergesort") - top = order[:k] - tp_top = int(y_true[top].sum()) - - precision = tp_top / k - recall = float("nan") if n_pos == 0 else (tp_top / n_pos) - return float(precision), float(recall) +# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`. +# NumPy/Pandas are only used for final CSV formatting/aggregation. 
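+#
+# A minimal sketch of the torch-native path (mirroring its use further down
+# in `evaluate_time_dependent_age_bins`), assuming `yb_t` is an (N, K) bool
+# label tensor and `pb_t` an (N, K) float score tensor for one
+# (mc, tau, bin) cell:
+#
+#   metrics = compute_binary_metrics_torch(
+#       y_true=yb_t, y_pred=pb_t,
+#       k_percents=topk_percents, tie_mode="exact",
+#   )
+#   auc = metrics.auc_per_cause.cpu().numpy()  # one device transfer per cell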
@dataclass @@ -341,18 +246,21 @@ def evaluate_time_dependent_age_bins( list(cfg.cause_ids), dtype=torch.long, device=device) n_causes_eval = int(cause_ids.numel()) - # Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops) - y_true: List[List[List[List[torch.Tensor]]]] = [ - [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))] - for _ in range(int(cfg.n_mc)) - ] - y_pred: List[List[List[List[torch.Tensor]]]] = [ - [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))] - for _ in range(int(cfg.n_mc)) - ] + rows_by_bin: List[Dict[str, float | int]] = [] for mc_idx in range(int(cfg.n_mc)): global_mc_idx = int(mc_offset) + int(mc_idx) + + # Storage for this MC only: (tau, bin) -> list of GPU tensors. + # This keeps computations GPU-first while preventing a factor-n_mc + # blow-up in GPU memory. + y_true_mc: List[List[List[torch.Tensor]]] = [ + [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) + ] + y_pred_mc: List[List[List[torch.Tensor]]] = [ + [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) + ] + # tqdm over batches; include MC idx in description. for batch_idx, batch in enumerate( tqdm(dataloader, @@ -429,24 +337,17 @@ def evaluate_time_dependent_age_bins( ) preds = cifs.index_select(dim=1, index=cause_ids) - # Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors - # and convert to NumPy once during aggregation. - y_true[mc_idx][tau_idx][bin_idx].append( - y[keep].detach().to(dtype=torch.bool, - device="cpu", non_blocking=True) + y_true_mc[tau_idx][bin_idx].append( + y[keep].detach().to(dtype=torch.bool) ) - y_pred[mc_idx][tau_idx][bin_idx].append( - preds[keep].detach().to(dtype=torch.float32, - device="cpu", non_blocking=True) + y_pred_mc[tau_idx][bin_idx].append( + preds[keep].detach().to(dtype=torch.float32) ) - rows_by_bin: List[Dict[str, float | int]] = [] - - for mc_idx in range(int(cfg.n_mc)): - global_mc_idx = int(mc_offset) + int(mc_idx) + # Aggregate this MC immediately (frees GPU memory before next MC). for h_idx, tau_y in enumerate(horizons_years): for bin_idx, (a_lo, a_hi) in enumerate(age_bins): - if len(y_true[mc_idx][h_idx][bin_idx]) == 0: + if len(y_true_mc[h_idx][bin_idx]) == 0: # No samples in this bin for this (mc, tau) for cause_k in range(n_causes_eval): cause_id = int(cause_k) if cause_ids is None else int( @@ -472,34 +373,37 @@ def evaluate_time_dependent_age_bins( ) continue - yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0) - pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0) + yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0) + pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0) if tuple(yb_t.shape) != tuple(pb_t.shape): raise ValueError( f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}" ) - yb = yb_t.numpy() - pb = pb_t.numpy() + n_samples = int(yb_t.size(0)) - n_samples = int(yb.shape[0]) + metrics = compute_binary_metrics_torch( + y_true=yb_t, + y_pred=pb_t, + k_percents=topk_percents, + tie_mode="exact", + chunk_size=128, + compute_ici=False, + ) + + # Move just the metric vectors to CPU once per (mc, tau, bin) + # for DataFrame construction. 
+ auc = metrics.auc_per_cause.detach().cpu().numpy() + auprc = metrics.ap_per_cause.detach().cpu().numpy() + brier = metrics.brier_per_cause.detach().cpu().numpy() + n_pos = metrics.n_pos_per_cause.detach().cpu().numpy() + prec_at_k = metrics.precision_at_k.detach().cpu().numpy() # (P,K) + rec_at_k = metrics.recall_at_k.detach().cpu().numpy() # (P,K) for cause_k in range(n_causes_eval): - yk = yb[:, cause_k] - pk = pb[:, cause_k] - n_pos = int(yk.sum()) - - auc = _binary_roc_auc(yk, pk) - auprc = _average_precision(yk, pk) - brier = float(np.mean( - (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan") - cause_id = int(cause_k) if cause_ids is None else int( cfg.cause_ids[cause_k]) - - for k_percent in topk_percents: - precision_k, recall_k = _precision_recall_at_k_percent( - yk, pk, float(k_percent)) + for p_idx, k_percent in enumerate(topk_percents): rows_by_bin.append( dict( mc_idx=global_mc_idx, @@ -510,12 +414,12 @@ def evaluate_time_dependent_age_bins( topk_percent=float(k_percent), cause_id=cause_id, n_samples=n_samples, - n_positives=n_pos, - auc=float(auc), - auprc=float(auprc), - recall_at_K=float(recall_k), - precision_at_K=float(precision_k), - brier_score=float(brier), + n_positives=int(n_pos[cause_k]), + auc=float(auc[cause_k]), + auprc=float(auprc[cause_k]), + recall_at_K=float(rec_at_k[p_idx, cause_k]), + precision_at_K=float(prec_at_k[p_idx, cause_k]), + brier_score=float(brier[cause_k]), ) ) diff --git a/torch_metrics.py b/torch_metrics.py new file mode 100644 index 0000000..803b1d1 --- /dev/null +++ b/torch_metrics.py @@ -0,0 +1,524 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Dict, Iterable, List, Literal, Optional, Sequence, Tuple + +import torch + + +TieMode = Literal["exact", "approx"] + + +def _stable_sort( + x: torch.Tensor, + *, + dim: int, + descending: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Stable torch sort when available. + + Determinism notes: + - When `stable=True` is supported by the installed PyTorch, we request it. + - Otherwise we fall back to `torch.sort`. For identical inputs on the same + device/runtime, this is typically deterministic, but tie ordering is not + guaranteed to be stable across versions. + """ + try: + return torch.sort(x, dim=dim, descending=descending, stable=True) + except TypeError: + return torch.sort(x, dim=dim, descending=descending) + + +def _nanmean(x: torch.Tensor) -> torch.Tensor: + mask = torch.isfinite(x) + if not bool(mask.any()): + return torch.tensor(float("nan"), device=x.device, dtype=x.dtype) + return x[mask].mean() + + +def _nanweighted_mean(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + x = x.to(torch.float32) + w = w.to(torch.float32) + mask = torch.isfinite(x) & torch.isfinite(w) & (w > 0) + if not bool(mask.any()): + return torch.tensor(float("nan"), device=x.device, dtype=torch.float32) + ww = w[mask] + return (x[mask] * ww).sum() / ww.sum() + + +def _validate_binary_inputs(y_true: torch.Tensor, y_score: torch.Tensor) -> None: + if y_true.ndim != 2 or y_score.ndim != 2: + raise ValueError( + f"Expected y_true and y_score to be 2D (N,K); got {tuple(y_true.shape)} and {tuple(y_score.shape)}" + ) + if tuple(y_true.shape) != tuple(y_score.shape): + raise ValueError( + f"Shape mismatch: y_true{tuple(y_true.shape)} vs y_score{tuple(y_score.shape)}" + ) + + +def brier_per_cause(y_true: torch.Tensor, y_score: torch.Tensor) -> torch.Tensor: + """Brier score per cause. 
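+
+    Per cause k: Brier_k = mean_i (y_score[i, k] - y_true[i, k]) ** 2.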
+ + Args: + y_true: (N,K) bool/int tensor + y_score: (N,K) float tensor + + Returns: + (K,) float32 tensor; NaN if N==0. + """ + _validate_binary_inputs(y_true, y_score) + if y_true.numel() == 0: + return torch.full((y_true.size(1),), float("nan"), device=y_true.device, dtype=torch.float32) + yt = y_true.to(torch.float32) + ys = y_score.to(torch.float32) + return ((ys - yt) ** 2).mean(dim=0) + + +def ici_per_cause_fixed_width( + y_true: torch.Tensor, + y_score: torch.Tensor, + *, + n_bins: int = 15, + chunk_size: int = 128, +) -> torch.Tensor: + """Integrated Calibration Index (ICI) via fixed-width bins on [0,1]. + + ICI per cause = E[ |p_bin - y_bin| ] where bin stats are computed over + fixed-width probability bins. + + This is deterministic and GPU-friendly (scatter_add based). + + Returns: + (K,) float32 tensor; NaN when N==0. + """ + _validate_binary_inputs(y_true, y_score) + if int(n_bins) <= 1: + raise ValueError("n_bins must be >= 2") + + device = y_true.device + N, K = y_true.shape + if N == 0: + return torch.full((K,), float("nan"), device=device, dtype=torch.float32) + + yt = y_true.to(torch.float32) + ys = y_score.to(torch.float32).clamp(0.0, 1.0) + + out = torch.full((K,), float("nan"), device=device, dtype=torch.float32) + + for start in range(0, K, int(chunk_size)): + end = min(K, start + int(chunk_size)) + ys_c = ys[:, start:end] + yt_c = yt[:, start:end] + + # bin index in [0, n_bins-1] + bin_idx = torch.clamp( + (ys_c * float(n_bins)).to(torch.long), max=int(n_bins) - 1) + + counts = torch.zeros((int(n_bins), end - start), + device=device, dtype=torch.float32) + pred_sums = torch.zeros_like(counts) + true_sums = torch.zeros_like(counts) + + ones = torch.ones_like(ys_c, dtype=torch.float32) + counts.scatter_add_(0, bin_idx, ones) + pred_sums.scatter_add_(0, bin_idx, ys_c) + true_sums.scatter_add_(0, bin_idx, yt_c) + + denom = counts.clamp(min=1.0) + pred_mean = pred_sums / denom + true_mean = true_sums / denom + + abs_gap = (pred_mean - true_mean).abs() + # sample-weighted average of bin gap + total = counts.sum(dim=0).clamp(min=1.0) + ici = (abs_gap * counts).sum(dim=0) / total + + # If a cause has no samples (shouldn't happen when N>0), mark NaN. + out[start:end] = torch.where( + total > 0, ici.to(torch.float32), out[start:end]) + + return out + + +def average_precision_per_cause( + y_true: torch.Tensor, + y_score: torch.Tensor, + *, + tie_mode: TieMode = "exact", + chunk_size: int = 128, +) -> torch.Tensor: + """Average precision (AP) per cause. + + Definition matches sklearn's `average_precision_score` (step-wise PR integral): + AP = \\sum_i (R_i - R_{i-1}) * P_i + where i iterates over unique score thresholds. + + tie_mode: + - "exact": tie-invariant AP by grouping identical scores (recommended) + - "approx": mean precision at positive ranks; can differ under ties + + Returns: + (K,) float32 tensor with NaN for causes with 0 positives. + """ + _validate_binary_inputs(y_true, y_score) + + device = y_true.device + N, K = y_true.shape + if N == 0: + return torch.full((K,), float("nan"), device=device, dtype=torch.float32) + + yt = y_true.to(torch.bool) + ys = y_score.to(torch.float32) + + n_pos_all = yt.sum(dim=0).to(torch.float32) + out = torch.full((K,), float("nan"), device=device, dtype=torch.float32) + + for start in range(0, K, int(chunk_size)): + end = min(K, start + int(chunk_size)) + yt_c = yt[:, start:end] + ys_c = ys[:, start:end] + + # For exact mode we need per-cause tie grouping; do per-cause loops + # within a chunk to keep memory bounded and stay on GPU. 
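+        # E.g. descending scores [0.9, 0.9, 0.3] form two threshold groups;
+        # precision/recall are read once per group end, so the result is
+        # invariant to how the tied 0.9 entries happen to be ordered.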
+ for j in range(end - start): + n_pos = n_pos_all[start + j] + + scores = ys_c[:, j] + labels = yt_c[:, j] + + if tie_mode == "approx": + _, order = _stable_sort(scores, dim=0, descending=True) + y_sorted = labels.gather(0, order).to(torch.float32) + tp = y_sorted.cumsum(dim=0) + denom = torch.arange( + 1, N + 1, device=device, dtype=torch.float32) + precision = tp / denom + n_pos_safe = torch.clamp(n_pos, min=1.0) + ap = (precision * y_sorted).sum() / n_pos_safe + out[start + j] = torch.where(n_pos > 0.0, + ap.to(torch.float32), out[start + j]) + continue + + # exact: group by unique score thresholds + sorted_scores, order = _stable_sort(scores, dim=0, descending=True) + y_sorted = labels.gather(0, order).to(torch.float32) + + # group boundaries where score changes + change = torch.empty((N,), device=device, dtype=torch.bool) + change[0] = True + if N > 1: + change[1:] = sorted_scores[1:] != sorted_scores[:-1] + group_starts = change.nonzero(as_tuple=False).squeeze(1) + group_ends = torch.cat( + [group_starts[1:], torch.tensor( + [N], device=device, dtype=group_starts.dtype)] + ) - 1 + + tp = y_sorted.cumsum(dim=0) + fp = torch.arange(1, N + 1, device=device, dtype=torch.float32) - tp + + tp_end = tp[group_ends] + fp_end = fp[group_ends] + precision = tp_end / torch.clamp(tp_end + fp_end, min=1.0) + n_pos_safe = torch.clamp(n_pos, min=1.0) + recall = tp_end / n_pos_safe + + recall_prev = torch.cat( + [torch.zeros((1,), device=device, + dtype=torch.float32), recall[:-1]] + ) + ap = ((recall - recall_prev) * precision).sum() + out[start + j] = torch.where(n_pos > 0.0, + ap.to(torch.float32), out[start + j]) + + return out + + +def auroc_per_cause( + y_true: torch.Tensor, + y_score: torch.Tensor, + *, + tie_mode: TieMode = "exact", + chunk_size: int = 128, +) -> torch.Tensor: + """AUROC per cause via Mann–Whitney U. + + AUC = (sum_ranks_pos - n_pos*(n_pos+1)/2) / (n_pos*n_neg) + + tie_mode: + - "exact": average ranks for ties (matches typical sklearn tie behavior) + - "approx": ordinal ranks (faster, differs under ties) + + Returns: + (K,) float32 tensor; NaN when a cause has n_pos==0 or n_neg==0. 
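+
+    Example (one cause, perfectly ranked, hence AUC 1.0):
+        >>> y = torch.tensor([[0], [0], [1]], dtype=torch.bool)
+        >>> s = torch.tensor([[0.1], [0.2], [0.9]])
+        >>> auroc_per_cause(y, s)
+        tensor([1.])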
+ """ + _validate_binary_inputs(y_true, y_score) + + device = y_true.device + N, K = y_true.shape + if N == 0: + return torch.full((K,), float("nan"), device=device, dtype=torch.float32) + + yt = y_true.to(torch.bool) + ys = y_score.to(torch.float32) + + n_pos_all = yt.sum(dim=0).to(torch.float32) + n_neg_all = (float(N) - n_pos_all).to(torch.float32) + + out = torch.full((K,), float("nan"), device=device, dtype=torch.float32) + + for start in range(0, K, int(chunk_size)): + end = min(K, start + int(chunk_size)) + yt_c = yt[:, start:end] + ys_c = ys[:, start:end] + + for j in range(end - start): + n_pos = n_pos_all[start + j] + n_neg = n_neg_all[start + j] + + scores = ys_c[:, j] + labels = yt_c[:, j] + + sorted_scores, order = _stable_sort(scores, dim=0, descending=False) + y_sorted = labels.gather(0, order).to(torch.float32) + + if tie_mode == "approx": + ranks = torch.arange( + 1, N + 1, device=device, dtype=torch.float32) + else: + # average ranks for ties + change = torch.empty((N,), device=device, dtype=torch.bool) + change[0] = True + if N > 1: + change[1:] = sorted_scores[1:] != sorted_scores[:-1] + group_starts = change.nonzero(as_tuple=False).squeeze(1) + group_ends = torch.cat( + [group_starts[1:], torch.tensor( + [N], device=device, dtype=group_starts.dtype)] + ) - 1 + + lengths = (group_ends - group_starts + 1).to(torch.long) + start_rank = (group_starts + 1).to(torch.float32) + end_rank = (group_ends + 1).to(torch.float32) + avg_rank = 0.5 * (start_rank + end_rank) + ranks = avg_rank.repeat_interleave(lengths) + + sum_ranks_pos = (ranks * y_sorted).sum() + u = sum_ranks_pos - (n_pos * (n_pos + 1.0) / 2.0) + denom = n_pos * n_neg + auc = u / torch.clamp(denom, min=1.0) + valid = (n_pos > 0.0) & (n_neg > 0.0) + out[start + j] = torch.where(valid, + auc.to(torch.float32), out[start + j]) + + return out + + +def precision_recall_at_k_percents_per_cause( + y_true: torch.Tensor, + y_score: torch.Tensor, + k_percents: Sequence[float], + *, + chunk_size: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Precision@K% and Recall@K% per cause. + + Uses stable sort (descending) to match deterministic tie behavior. 
+ + Returns: + precision: (P,K) float32 + recall: (P,K) float32 + """ + _validate_binary_inputs(y_true, y_score) + + device = y_true.device + N, K = y_true.shape + P = int(len(k_percents)) + + precision = torch.full((P, K), float( + "nan"), device=device, dtype=torch.float32) + recall = torch.full((P, K), float( + "nan"), device=device, dtype=torch.float32) + + if N == 0: + return precision, recall + + yt = y_true.to(torch.bool) + ys = y_score.to(torch.float32) + + n_pos_all = yt.sum(dim=0).to(torch.float32) + + ks: List[int] = [] + for kp in k_percents: + k = int(math.ceil((float(kp) / 100.0) * float(N))) + ks.append(k) + + for start in range(0, K, int(chunk_size)): + end = min(K, start + int(chunk_size)) + yt_c = yt[:, start:end] + ys_c = ys[:, start:end] + + for j in range(end - start): + scores = ys_c[:, j] + labels = yt_c[:, j] + n_pos = n_pos_all[start + j] + + # stable descending order + _, order = _stable_sort(scores, dim=0, descending=True) + y_sorted = labels.gather(0, order).to(torch.float32) + tp = y_sorted.cumsum(dim=0) + + for p_idx, k in enumerate(ks): + if k <= 0: + continue + tp_k = tp[min(k, N) - 1] + precision[p_idx, start + j] = tp_k / float(k) + recall[p_idx, start + j] = torch.where( + n_pos > 0.0, + tp_k / n_pos, + torch.tensor(float("nan"), device=device, + dtype=torch.float32), + ) + + return precision, recall + + +@dataclass +class BinaryMetricsResult: + auc_per_cause: torch.Tensor # (K,) + ap_per_cause: torch.Tensor # (K,) + brier_per_cause: torch.Tensor # (K,) + precision_at_k: torch.Tensor # (P,K) + recall_at_k: torch.Tensor # (P,K) + n_pos_per_cause: torch.Tensor # (K,) + n_neg_per_cause: torch.Tensor # (K,) + ici_per_cause: Optional[torch.Tensor] = None # (K,) + + +@torch.inference_mode() +def compute_binary_metrics_torch( + y_true: torch.Tensor, + y_pred: torch.Tensor, + *, + device: str | torch.device | None = None, + k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0), + tie_mode: TieMode = "exact", + chunk_size: int = 128, + compute_ici: bool = False, + ici_bins: int = 15, +) -> BinaryMetricsResult: + """Compute per-cause binary ranking metrics on GPU using torch. + + Inputs must be (N,K) and live on the device you want to compute on. + + Performance notes: + - Computation is chunked over causes to bound peak memory. + - For `tie_mode="exact"`, AP and AUROC are computed with tie grouping, which + is more correct under ties but uses per-cause loops (still GPU-resident). + + Determinism: + - Uses stable sorts where available. + - Avoids nondeterministic selection ops for ties (no `topk`). 
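+
+    Example (shapes only; values here are random):
+        >>> y = torch.rand(256, 8) > 0.5
+        >>> p = torch.rand(256, 8)
+        >>> res = compute_binary_metrics_torch(y, p, k_percents=(5.0, 10.0))
+        >>> res.auc_per_cause.shape, res.precision_at_k.shape
+        (torch.Size([8]), torch.Size([2, 8]))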
+ """ + _validate_binary_inputs(y_true, y_pred) + + if device is not None: + device = torch.device(device) + y_true = y_true.to(device) + y_pred = y_pred.to(device) + + N, K = y_true.shape + + yt = y_true.to(torch.bool) + yp = y_pred.to(torch.float32) + + n_pos = yt.sum(dim=0).to(torch.long) + n_neg = (int(N) - n_pos).to(torch.long) + + auc = auroc_per_cause(yt, yp, tie_mode=tie_mode, chunk_size=chunk_size) + ap = average_precision_per_cause( + yt, yp, tie_mode=tie_mode, chunk_size=chunk_size) + brier = brier_per_cause(yt, yp) + + prec_k, rec_k = precision_recall_at_k_percents_per_cause( + yt, yp, k_percents, chunk_size=chunk_size + ) + + ici = None + if compute_ici: + ici = ici_per_cause_fixed_width( + yt, yp, n_bins=int(ici_bins), chunk_size=chunk_size) + + return BinaryMetricsResult( + auc_per_cause=auc, + ap_per_cause=ap, + brier_per_cause=brier, + precision_at_k=prec_k, + recall_at_k=rec_k, + n_pos_per_cause=n_pos, + n_neg_per_cause=n_neg, + ici_per_cause=ici, + ) + + +@torch.inference_mode() +def compute_metrics_torch( + y_true: torch.Tensor, + y_pred: torch.Tensor, + *, + device: str | torch.device | None = None, + weights: Optional[torch.Tensor] = None, + k_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0), + tie_mode: TieMode = "exact", + chunk_size: int = 128, + compute_ici: bool = False, + ici_bins: int = 15, +) -> Dict[str, object]: + """Convenience API: per-cause + macro/weighted aggregations. + + Returns a dict compatible with downstream reporting: + - per-cause tensors under `per_cause` + - macro + weighted summaries (NaN-aware) + + If `weights` is None, uses number of positives per cause as weights. + """ + res = compute_binary_metrics_torch( + y_true, + y_pred, + device=device, + k_percents=k_percents, + tie_mode=tie_mode, + chunk_size=chunk_size, + compute_ici=compute_ici, + ici_bins=ici_bins, + ) + + w = res.n_pos_per_cause.to( + torch.float32) if weights is None else weights.to(torch.float32) + + out: Dict[str, object] = { + "auc_macro": _nanmean(res.auc_per_cause), + "auc_weighted": _nanweighted_mean(res.auc_per_cause, w), + "ap_macro": _nanmean(res.ap_per_cause), + "ap_weighted": _nanweighted_mean(res.ap_per_cause, w), + "brier_macro": _nanmean(res.brier_per_cause), + "brier_weighted": _nanweighted_mean(res.brier_per_cause, w), + "per_cause": { + "auc": res.auc_per_cause, + "ap": res.ap_per_cause, + "brier": res.brier_per_cause, + "precision_at_k": res.precision_at_k, + "recall_at_k": res.recall_at_k, + "n_pos": res.n_pos_per_cause, + "n_neg": res.n_neg_per_cause, + }, + } + + if res.ici_per_cause is not None: + out["ici_macro"] = _nanmean(res.ici_per_cause) + out["ici_weighted"] = _nanweighted_mean(res.ici_per_cause, w) + out["per_cause"]["ici"] = res.ici_per_cause + + return out