Add binary metrics computation and refactor evaluation logic in age-bin evaluation

2026-01-16 17:19:27 +08:00
parent b1647d1b74
commit 810c75e6d1
2 changed files with 573 additions and 145 deletions


@@ -20,6 +20,8 @@ from utils import (
     sample_context_in_fixed_age_bin,
 )
 
+from torch_metrics import compute_binary_metrics_torch
+
 
 def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
     """Aggregate per-bin age evaluation results.
@@ -173,105 +175,8 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
     return df_agg
 
 
-def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
-    """Compute ROC AUC for binary labels with tie-aware ranking.
-
-    Returns NaN if y_true has no positives or no negatives.
-    Uses the Mann-Whitney U statistic with average ranks for ties.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-    n = y_true.size
-    if n == 0:
-        return float("nan")
-    n_pos = int(y_true.sum())
-    n_neg = n - n_pos
-    if n_pos == 0 or n_neg == 0:
-        return float("nan")
-    # Rank scores ascending, average ranks for ties.
-    order = np.argsort(y_score, kind="mergesort")
-    sorted_scores = y_score[order]
-    ranks = np.empty(n, dtype=float)
-    i = 0
-    # 1-based ranks
-    while i < n:
-        j = i + 1
-        while j < n and sorted_scores[j] == sorted_scores[i]:
-            j += 1
-        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
-        ranks[order[i:j]] = avg_rank
-        i = j
-    sum_ranks_pos = float(ranks[y_true].sum())
-    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
-    return float(u / (n_pos * n_neg))
-
-
-def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
-    """Average precision (area under PR curve using step-wise interpolation).
-
-    Returns NaN if no positives.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-    n = y_true.size
-    if n == 0:
-        return float("nan")
-    n_pos = int(y_true.sum())
-    if n_pos == 0:
-        return float("nan")
-    order = np.argsort(-y_score, kind="mergesort")
-    y = y_true[order]
-    tp = np.cumsum(y).astype(float)
-    fp = np.cumsum(~y).astype(float)
-    precision = tp / np.maximum(tp + fp, 1.0)
-    # AP = sum over each positive of precision at that point / n_pos
-    # (equivalent to ∑ Δrecall * precision)
-    ap = float(np.sum(precision[y]) / n_pos)
-    # Handle potential tiny numerical overshoots.
-    return float(max(0.0, min(1.0, ap)))
-
-
-def _precision_recall_at_k_percent(
-    y_true: np.ndarray,
-    y_score: np.ndarray,
-    k_percent: float,
-) -> Tuple[float, float]:
-    """Precision@K% and Recall@K% for binary labels.
-
-    Returns (precision, recall). Returns NaN for recall if no positives.
-    Returns NaN for precision if k leads to 0 selected.
-    """
-    y_true = np.asarray(y_true).astype(bool)
-    y_score = np.asarray(y_score).astype(float)
-    n = y_true.size
-    if n == 0:
-        return float("nan"), float("nan")
-    n_pos = int(y_true.sum())
-    k = int(math.ceil((float(k_percent) / 100.0) * n))
-    if k <= 0:
-        return float("nan"), float("nan")
-    order = np.argsort(-y_score, kind="mergesort")
-    top = order[:k]
-    tp_top = int(y_true[top].sum())
-    precision = tp_top / k
-    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
-    return float(precision), float(recall)
+# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`.
+# NumPy/Pandas are only used for final CSV formatting/aggregation.
 
 
 @dataclass
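The three helpers deleted above fix the semantics the torch port has to reproduce: ROC AUC as the normalized Mann-Whitney U with average ranks for ties, and average precision as the mean of precision at each positive. As a quick editorial sanity check (not part of the diff), the rank-based AUC agrees with the brute-force pairwise definition, with ties counted at half weight:

```python
import numpy as np

# Brute-force Mann-Whitney AUC: P(s_pos > s_neg) + 0.5 * P(s_pos == s_neg).
def auc_bruteforce(y_true: np.ndarray, y_score: np.ndarray) -> float:
    y = np.asarray(y_true, dtype=bool)
    pos, neg = y_score[y], y_score[~y]
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((wins + 0.5 * ties) / (pos.size * neg.size))

y = np.array([1, 0, 1, 0, 1], dtype=bool)
s = np.array([0.9, 0.9, 0.4, 0.2, 0.2])
# 6 pos/neg pairs: 2 wins, 2 ties, 2 losses -> (2 + 0.5 * 2) / 6 = 0.5;
# _binary_roc_auc above returns the same value via average ranks.
print(auc_bruteforce(y, s))  # 0.5
```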
@@ -341,18 +246,21 @@ def evaluate_time_dependent_age_bins(
         list(cfg.cause_ids), dtype=torch.long, device=device)
     n_causes_eval = int(cause_ids.numel())
 
-    # Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
-    y_true: List[List[List[List[torch.Tensor]]]] = [
-        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
-        for _ in range(int(cfg.n_mc))
-    ]
-    y_pred: List[List[List[List[torch.Tensor]]]] = [
-        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
-        for _ in range(int(cfg.n_mc))
-    ]
+    rows_by_bin: List[Dict[str, float | int]] = []
 
     for mc_idx in range(int(cfg.n_mc)):
         global_mc_idx = int(mc_offset) + int(mc_idx)
+        # Storage for this MC only: (tau, bin) -> list of GPU tensors.
+        # This keeps computations GPU-first while preventing a factor-n_mc
+        # blow-up in GPU memory.
+        y_true_mc: List[List[List[torch.Tensor]]] = [
+            [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
+        ]
+        y_pred_mc: List[List[List[torch.Tensor]]] = [
+            [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
+        ]
 
         # tqdm over batches; include MC idx in description.
         for batch_idx, batch in enumerate(
             tqdm(dataloader,
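The storage change above is the core of the refactor: buffers for one MC draw live on the GPU only until that draw is aggregated, so peak memory no longer scales with n_mc. A minimal stand-in for the pattern (names and shapes below are illustrative, not from the repository):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
n_mc, n_bins = 4, 3
rows = []
for mc_idx in range(n_mc):
    buffers = [[] for _ in range(n_bins)]       # GPU tensors for this MC only
    for _ in range(8):                          # stand-in for the batch loop
        bin_idx = int(torch.randint(0, n_bins, (1,)))
        buffers[bin_idx].append(torch.rand(16, device=device))
    for bin_idx, chunks in enumerate(buffers):  # aggregate this MC immediately
        if chunks:
            rows.append((mc_idx, bin_idx, float(torch.cat(chunks).mean())))
    del buffers                                 # freed before the next MC draw
```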
@@ -429,24 +337,17 @@ def evaluate_time_dependent_age_bins(
                 )
                 preds = cifs.index_select(dim=1, index=cause_ids)
 
-                # Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors
-                # and convert to NumPy once during aggregation.
-                y_true[mc_idx][tau_idx][bin_idx].append(
-                    y[keep].detach().to(dtype=torch.bool,
-                                        device="cpu", non_blocking=True)
+                y_true_mc[tau_idx][bin_idx].append(
+                    y[keep].detach().to(dtype=torch.bool)
                 )
-                y_pred[mc_idx][tau_idx][bin_idx].append(
-                    preds[keep].detach().to(dtype=torch.float32,
-                                            device="cpu", non_blocking=True)
+                y_pred_mc[tau_idx][bin_idx].append(
+                    preds[keep].detach().to(dtype=torch.float32)
                 )
 
-    rows_by_bin: List[Dict[str, float | int]] = []
-    for mc_idx in range(int(cfg.n_mc)):
-        global_mc_idx = int(mc_offset) + int(mc_idx)
+        # Aggregate this MC immediately (frees GPU memory before next MC).
         for h_idx, tau_y in enumerate(horizons_years):
             for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
-                if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
+                if len(y_true_mc[h_idx][bin_idx]) == 0:
                     # No samples in this bin for this (mc, tau)
                     for cause_k in range(n_causes_eval):
                         cause_id = int(cause_k) if cause_ids is None else int(
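One subtlety in the appends above: dropping the explicit device="cpu" is what keeps the buffers on the GPU, because a dtype-only .to(...) preserves the source tensor's device:

```python
import torch

x = torch.rand(4, device="cuda" if torch.cuda.is_available() else "cpu")
assert x.detach().to(dtype=torch.bool).device == x.device  # no host transfer
```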
@@ -472,34 +373,37 @@ def evaluate_time_dependent_age_bins(
                         )
                     continue
 
-                yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
-                pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
+                yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
+                pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0)
                 if tuple(yb_t.shape) != tuple(pb_t.shape):
                     raise ValueError(
                         f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
                     )
-                yb = yb_t.numpy()
-                pb = pb_t.numpy()
-                n_samples = int(yb.shape[0])
+                n_samples = int(yb_t.size(0))
+                metrics = compute_binary_metrics_torch(
+                    y_true=yb_t,
+                    y_pred=pb_t,
+                    k_percents=topk_percents,
+                    tie_mode="exact",
+                    chunk_size=128,
+                    compute_ici=False,
+                )
+                # Move just the metric vectors to CPU once per (mc, tau, bin)
+                # for DataFrame construction.
+                auc = metrics.auc_per_cause.detach().cpu().numpy()
+                auprc = metrics.ap_per_cause.detach().cpu().numpy()
+                brier = metrics.brier_per_cause.detach().cpu().numpy()
+                n_pos = metrics.n_pos_per_cause.detach().cpu().numpy()
+                prec_at_k = metrics.precision_at_k.detach().cpu().numpy()  # (P,K)
+                rec_at_k = metrics.recall_at_k.detach().cpu().numpy()  # (P,K)
 
                 for cause_k in range(n_causes_eval):
-                    yk = yb[:, cause_k]
-                    pk = pb[:, cause_k]
-                    n_pos = int(yk.sum())
-                    auc = _binary_roc_auc(yk, pk)
-                    auprc = _average_precision(yk, pk)
-                    brier = float(np.mean(
-                        (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan")
                     cause_id = int(cause_k) if cause_ids is None else int(
                         cfg.cause_ids[cause_k])
-                    for k_percent in topk_percents:
-                        precision_k, recall_k = _precision_recall_at_k_percent(
-                            yk, pk, float(k_percent))
+                    for p_idx, k_percent in enumerate(topk_percents):
                         rows_by_bin.append(
                             dict(
                                 mc_idx=global_mc_idx,
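compute_binary_metrics_torch itself lives in torch_metrics.py and is not part of this diff; only its call signature and the per-cause / (P percents, K causes) shapes of its outputs are visible above. For orientation, a tie-aware per-cause ROC AUC with the same (N, C) -> (C,) contract could be sketched as follows (an editorial sketch with O(N²) memory, not the repository's implementation, which is presumably chunked via chunk_size):

```python
import torch

def roc_auc_per_cause(y_true: torch.Tensor, y_score: torch.Tensor) -> torch.Tensor:
    """Tie-aware ROC AUC per column. y_true: (N, C) bool; y_score: (N, C) float."""
    y = y_true.to(torch.bool)
    n_pos = y.sum(dim=0).to(y_score.dtype)  # (C,)
    n_neg = (~y).sum(dim=0).to(y_score.dtype)
    # Average 1-based rank of each score within its column:
    # (# strictly smaller) + ((# equal, incl. self) + 1) / 2.
    smaller = (y_score.unsqueeze(0) < y_score.unsqueeze(1)).to(y_score.dtype)
    equal = (y_score.unsqueeze(0) == y_score.unsqueeze(1)).to(y_score.dtype)
    ranks = smaller.sum(dim=1) + 0.5 * (equal.sum(dim=1) + 1.0)  # (N, C)
    u = (ranks * y.to(y_score.dtype)).sum(dim=0) - n_pos * (n_pos + 1.0) / 2.0
    return u / (n_pos * n_neg)  # NaN where a column lacks positives or negatives

y = torch.tensor([[1], [0], [1], [0], [1]], dtype=torch.bool)
s = torch.tensor([[0.9], [0.9], [0.4], [0.2], [0.2]])
print(roc_auc_per_cause(y, s))  # tensor([0.5000])
```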
@@ -510,12 +414,12 @@ def evaluate_time_dependent_age_bins(
                                 topk_percent=float(k_percent),
                                 cause_id=cause_id,
                                 n_samples=n_samples,
-                                n_positives=n_pos,
-                                auc=float(auc),
-                                auprc=float(auprc),
-                                recall_at_K=float(recall_k),
-                                precision_at_K=float(precision_k),
-                                brier_score=float(brier),
+                                n_positives=int(n_pos[cause_k]),
+                                auc=float(auc[cause_k]),
+                                auprc=float(auprc[cause_k]),
+                                recall_at_K=float(rec_at_k[p_idx, cause_k]),
+                                precision_at_K=float(prec_at_k[p_idx, cause_k]),
+                                brier_score=float(brier[cause_k]),
                             )
                         )
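Each appended dict is one long-format row keyed by (MC draw, horizon, age bin, K%, cause); aggregate_age_bin_results (truncated at the top of the diff) then reduces over MC draws. A hypothetical stand-in using only columns visible in this hunk (the real grouping keys presumably also include the horizon and age-bin fields elided above):

```python
import pandas as pd

# Two fake MC draws for one cause at K=5%, standing in for rows_by_bin.
rows_by_bin = [
    dict(mc_idx=0, topk_percent=5.0, cause_id=3, auc=0.81, auprc=0.24,
         precision_at_K=0.30, recall_at_K=0.55, brier_score=0.031),
    dict(mc_idx=1, topk_percent=5.0, cause_id=3, auc=0.79, auprc=0.22,
         precision_at_K=0.28, recall_at_K=0.51, brier_score=0.033),
]
metric_cols = ["auc", "auprc", "precision_at_K", "recall_at_K", "brier_score"]
df_agg = (
    pd.DataFrame(rows_by_bin)
    .groupby(["cause_id", "topk_percent"])[metric_cols]
    .agg(["mean", "std"])  # spread across MC draws
    .reset_index()
)
print(df_agg)
```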