Add binary metrics computation and refactor evaluation logic in age-bin evaluation
@@ -20,6 +20,8 @@ from utils import (
    sample_context_in_fixed_age_bin,
)

from torch_metrics import compute_binary_metrics_torch


def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-bin age evaluation results.
@@ -173,105 +175,8 @@ def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
    return df_agg


def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Returns NaN if y_true has no positives or no negatives.

    Uses the Mann–Whitney U statistic with average ranks for ties.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")

    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Rank scores ascending, average ranks for ties.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]

    ranks = np.empty(n, dtype=float)
    i = 0
    # 1-based ranks
    while i < n:
        j = i + 1
        while j < n and sorted_scores[j] == sorted_scores[i]:
            j += 1
        avg_rank = 0.5 * ((i + 1) + j)  # ranks i+1 .. j
        ranks[order[i:j]] = avg_rank
        i = j

    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))
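
# Illustrative check (hypothetical numbers, not part of this diff): for
# y_true = [0, 0, 1, 1] and y_score = [0.1, 0.4, 0.35, 0.8], the positives
# outrank the negatives in 3 of the 4 positive/negative pairs, so
# _binary_roc_auc returns 3 / (2 * 2) = 0.75.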


def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under PR curve using step-wise interpolation).

    Returns NaN if no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")

    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")

    order = np.argsort(-y_score, kind="mergesort")
    y = y_true[order]

    tp = np.cumsum(y).astype(float)
    fp = np.cumsum(~y).astype(float)
    precision = tp / np.maximum(tp + fp, 1.0)

    # AP = sum over each positive of precision at that point / n_pos
    # (equivalent to ∑ Δrecall * precision)
    ap = float(np.sum(precision[y]) / n_pos)
    # handle potential tiny numerical overshoots
    return float(max(0.0, min(1.0, ap)))
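
# Illustrative check (hypothetical numbers, not part of this diff): for
# y_true = [1, 0, 1] and y_score = [0.9, 0.8, 0.7], the precision at the two
# positives is 1/1 and 2/3, so _average_precision returns (1 + 2/3) / 2 ≈ 0.833.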


def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    Returns (precision, recall). Recall is NaN if there are no positives;
    precision is NaN if K% selects zero samples.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")

    n_pos = int(y_true.sum())

    k = int(math.ceil((float(k_percent) / 100.0) * n))
    if k <= 0:
        return float("nan"), float("nan")

    order = np.argsort(-y_score, kind="mergesort")
    top = order[:k]
    tp_top = int(y_true[top].sum())

    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)
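
# Illustrative check (hypothetical numbers, not part of this diff): with n = 10
# samples and k_percent = 20, k = ceil(2.0) = 2; if the top-2 scores contain 1 of
# 4 positives, the function returns precision 1/2 = 0.5 and recall 1/4 = 0.25.
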
# NOTE: metric computation is torch/GPU-native in `torch_metrics.py`.
# NumPy/Pandas are only used for final CSV formatting/aggregation.


@dataclass
@@ -341,18 +246,21 @@ def evaluate_time_dependent_age_bins(
        list(cfg.cause_ids), dtype=torch.long, device=device)
    n_causes_eval = int(cause_ids.numel())

    # Storage: (mc, h, bin) -> list of CPU tensors (avoid .numpy() in inner loops)
    y_true: List[List[List[List[torch.Tensor]]]] = [
        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
        for _ in range(int(cfg.n_mc))
    ]
    y_pred: List[List[List[List[torch.Tensor]]]] = [
        [[[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))]
        for _ in range(int(cfg.n_mc))
    ]
    rows_by_bin: List[Dict[str, float | int]] = []

    for mc_idx in range(int(cfg.n_mc)):
        global_mc_idx = int(mc_offset) + int(mc_idx)

        # Storage for this MC only: (tau, bin) -> list of GPU tensors.
        # This keeps computations GPU-first while preventing a factor-n_mc
        # blow-up in GPU memory.
        y_true_mc: List[List[List[torch.Tensor]]] = [
            [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
        ]
        y_pred_mc: List[List[List[torch.Tensor]]] = [
            [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years))
        ]

        # tqdm over batches; include MC idx in description.
        for batch_idx, batch in enumerate(
            tqdm(dataloader,
@@ -429,24 +337,17 @@ def evaluate_time_dependent_age_bins(
                )
                preds = cifs.index_select(dim=1, index=cause_ids)

                # Reduce CPU/NumPy conversion overhead: keep as CPU torch tensors
                # and convert to NumPy once during aggregation.
                y_true[mc_idx][tau_idx][bin_idx].append(
                    y[keep].detach().to(dtype=torch.bool,
                                        device="cpu", non_blocking=True)
                y_true_mc[tau_idx][bin_idx].append(
                    y[keep].detach().to(dtype=torch.bool)
                )
                y_pred[mc_idx][tau_idx][bin_idx].append(
                    preds[keep].detach().to(dtype=torch.float32,
                                            device="cpu", non_blocking=True)
                y_pred_mc[tau_idx][bin_idx].append(
                    preds[keep].detach().to(dtype=torch.float32)
                )

    rows_by_bin: List[Dict[str, float | int]] = []

    for mc_idx in range(int(cfg.n_mc)):
        global_mc_idx = int(mc_offset) + int(mc_idx)
        # Aggregate this MC immediately (frees GPU memory before next MC).
        for h_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                if len(y_true[mc_idx][h_idx][bin_idx]) == 0:
                if len(y_true_mc[h_idx][bin_idx]) == 0:
                    # No samples in this bin for this (mc, tau)
                    for cause_k in range(n_causes_eval):
                        cause_id = int(cause_k) if cause_ids is None else int(
@@ -472,34 +373,37 @@ def evaluate_time_dependent_age_bins(
                            )
                            continue

                yb_t = torch.cat(y_true[mc_idx][h_idx][bin_idx], dim=0)
                pb_t = torch.cat(y_pred[mc_idx][h_idx][bin_idx], dim=0)
                yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0)
                pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0)
                if tuple(yb_t.shape) != tuple(pb_t.shape):
                    raise ValueError(
                        f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}"
                    )

                yb = yb_t.numpy()
                pb = pb_t.numpy()
                n_samples = int(yb_t.size(0))

                n_samples = int(yb.shape[0])
                metrics = compute_binary_metrics_torch(
                    y_true=yb_t,
                    y_pred=pb_t,
                    k_percents=topk_percents,
                    tie_mode="exact",
                    chunk_size=128,
                    compute_ici=False,
                )

                # Move just the metric vectors to CPU once per (mc, tau, bin)
                # for DataFrame construction.
                auc = metrics.auc_per_cause.detach().cpu().numpy()
                auprc = metrics.ap_per_cause.detach().cpu().numpy()
                brier = metrics.brier_per_cause.detach().cpu().numpy()
                n_pos = metrics.n_pos_per_cause.detach().cpu().numpy()
                prec_at_k = metrics.precision_at_k.detach().cpu().numpy()  # (P,K)
                rec_at_k = metrics.recall_at_k.detach().cpu().numpy()  # (P,K)

                for cause_k in range(n_causes_eval):
                    yk = yb[:, cause_k]
                    pk = pb[:, cause_k]
                    n_pos = int(yk.sum())

                    auc = _binary_roc_auc(yk, pk)
                    auprc = _average_precision(yk, pk)
                    brier = float(np.mean(
                        (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else float("nan")

                    cause_id = int(cause_k) if cause_ids is None else int(
                        cfg.cause_ids[cause_k])

                    for k_percent in topk_percents:
                        precision_k, recall_k = _precision_recall_at_k_percent(
                            yk, pk, float(k_percent))
                    for p_idx, k_percent in enumerate(topk_percents):
                        rows_by_bin.append(
                            dict(
                                mc_idx=global_mc_idx,
@@ -510,12 +414,12 @@ def evaluate_time_dependent_age_bins(
                                topk_percent=float(k_percent),
                                cause_id=cause_id,
                                n_samples=n_samples,
                                n_positives=n_pos,
                                auc=float(auc),
                                auprc=float(auprc),
                                recall_at_K=float(recall_k),
                                precision_at_K=float(precision_k),
                                brier_score=float(brier),
                                n_positives=int(n_pos[cause_k]),
                                auc=float(auc[cause_k]),
                                auprc=float(auprc[cause_k]),
                                recall_at_K=float(rec_at_k[p_idx, cause_k]),
                                precision_at_K=float(prec_at_k[p_idx, cause_k]),
                                brier_score=float(brier[cause_k]),
                            )
                        )
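
For reference, below is a minimal usage sketch of the torch-native metrics entry point this commit introduces. The call signature and result fields are inferred from the call site in the diff above; the tensor shapes, dummy data, and example k_percents values are illustrative assumptions, not part of the commit.

import torch

from torch_metrics import compute_binary_metrics_torch

# Hypothetical inputs: N samples x C causes (shapes chosen for illustration).
y_true = torch.rand(1000, 3) < 0.1                  # (N, C) boolean event labels, ~10% positive
y_pred = torch.rand(1000, 3, dtype=torch.float32)   # (N, C) predicted risks in [0, 1]

metrics = compute_binary_metrics_torch(
    y_true=y_true,
    y_pred=y_pred,
    k_percents=[1.0, 5.0, 10.0],  # assumed example values
    tie_mode="exact",
    chunk_size=128,
    compute_ici=False,
)

print(metrics.auc_per_cause)   # per-cause ROC AUC, shape (C,)
print(metrics.ap_per_cause)    # per-cause average precision, shape (C,)
print(metrics.precision_at_k)  # precision@K%, shape (len(k_percents), C)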