Add evaluation and utility functions for time-dependent metrics
- Introduced `evaluate.py` for time-dependent evaluation of models, including data loading and model inference. - Added `evaluation_time_dependent.py` to compute various evaluation metrics such as AUC, average precision, and precision/recall at specified thresholds. - Implemented CIF calculation methods in `losses.py` for different loss types, including exponential and piecewise exponential models. - Created utility functions in `utils.py` for context selection and multi-hot encoding of events within specified horizons.
This commit is contained in:
316
evaluation_time_dependent.py
Normal file
316
evaluation_time_dependent.py
Normal file
@@ -0,0 +1,316 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from utils import (
|
||||
DAYS_PER_YEAR,
|
||||
multi_hot_ever_within_horizon,
|
||||
multi_hot_selected_causes_within_horizon,
|
||||
select_context_indices,
|
||||
)
|
||||
|
||||
|
||||
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Uses the Mann-Whitney U statistic: scores are ranked ascending
    (1-based), tied scores share the average of the ranks they span, and
    ``U = sum(ranks of positives) - n_pos * (n_pos + 1) / 2``.

    Args:
        y_true: Binary labels; cast to ``bool``.
        y_score: Scores, same length as ``y_true``; cast to ``float``.

    Returns:
        ``U / (n_pos * n_neg)`` as a float, or NaN when the input is empty
        or contains no positives or no negatives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")

    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Stable ascending sort; ranks are assigned per run of equal scores.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]

    # True where a new run of tied scores begins. Element-wise `!=` keeps
    # the same equality semantics as a scalar `==` scan (inf ties with inf,
    # NaN never ties with anything).
    starts_group = np.empty(n, dtype=bool)
    starts_group[0] = True
    starts_group[1:] = sorted_scores[1:] != sorted_scores[:-1]

    first = np.flatnonzero(starts_group)   # 0-based start of each run
    end = np.append(first[1:], n)          # exclusive end of each run
    # A run occupies 1-based ranks first+1 .. end; their mean is the tie rank.
    avg_rank = 0.5 * ((first + 1) + end)

    ranks = np.empty(n, dtype=float)
    ranks[order] = avg_rank[np.cumsum(starts_group) - 1]

    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))
|
||||
|
||||
|
||||
def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (step-wise area under the precision-recall curve).

    Precision and recall are evaluated at each *distinct* score threshold,
    so samples with tied scores enter the ranking together and the result
    does not depend on the order in which tied samples appear in the input.
    When all scores are distinct this equals the mean of the precision
    values observed at each positive.

    Args:
        y_true: Binary labels; cast to ``bool``.
        y_score: Scores, same length as ``y_true``; cast to ``float``.

    Returns:
        ``sum_i (R_i - R_{i-1}) * P_i`` over distinct thresholds, clipped to
        ``[0, 1]``; NaN when the input is empty or has no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")

    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")

    # Descending by score (stable).
    order = np.argsort(-y_score, kind="mergesort")
    sorted_scores = y_score[order]
    hits_cum = np.cumsum(y_true[order])

    # Keep only the last position of each run of equal scores: that is the
    # point where the whole tie group has been admitted at that threshold.
    ends_group = np.empty(n, dtype=bool)
    ends_group[-1] = True
    ends_group[:-1] = sorted_scores[1:] != sorted_scores[:-1]
    cut = np.flatnonzero(ends_group)

    tp = hits_cum[cut].astype(float)
    n_selected = (cut + 1).astype(float)
    precision = tp / n_selected
    recall = tp / n_pos

    delta_recall = np.diff(recall, prepend=0.0)
    ap = float(np.sum(delta_recall * precision))
    # Guard against tiny numerical overshoots.
    return float(max(0.0, min(1.0, ap)))
|
||||
|
||||
|
||||
def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    The top ``k = ceil(k_percent * n / 100)`` samples by score (descending,
    stable sort, so ties at the cutoff are broken by input order) are
    treated as predicted positives. ``k`` is capped at ``n``.

    Args:
        y_true: Binary labels; cast to ``bool``.
        y_score: Scores, same length as ``y_true``; cast to ``float``.
        k_percent: Percentage of samples to select.

    Returns:
        ``(precision, recall)``. Both are NaN when the input is empty or
        ``k`` evaluates to 0; recall alone is NaN when there are no
        positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")

    n_pos = int(y_true.sum())

    # Multiply before dividing: (7 / 100) * 100 == 7.000000000000001 in
    # floating point, which would make ceil() select one sample too many.
    k = int(math.ceil(float(k_percent) * n / 100.0))
    if k <= 0:
        return float("nan"), float("nan")
    # Never divide by more samples than can actually be selected.
    k = min(k, n)

    order = np.argsort(-y_score, kind="mergesort")
    tp_top = int(y_true[order[:k]].sum())

    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)
|
||||
|
||||
|
||||
@dataclass
class EvalConfig:
    """Settings consumed by ``evaluate_time_dependent``.

    Field order is part of the positional constructor signature.
    """

    # Prediction horizons; one set of metric rows is produced per horizon.
    # Must be non-empty.
    horizons_years: Sequence[float]

    # Offset forwarded to context selection when picking each sample's
    # context position.
    offset_years: float = 0.0

    # Percentages used for precision/recall@K%; each must lie in (0, 100].
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)

    # Subset of cause ids to evaluate; ``None`` evaluates all causes.
    cause_ids: Optional[Sequence[int]] = None
|
||||
|
||||
|
||||
@torch.no_grad()
def evaluate_time_dependent(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalConfig,
    device: str | torch.device,
) -> pd.DataFrame:
    """Evaluate time-dependent metrics per cause and per horizon.

    For every batch the backbone is run once, one context position per
    sample is selected, and ``criterion.calculate_cifs`` produces the
    cumulative incidence for all horizons at once. Labels and predictions
    of the eligible samples are accumulated per horizon, then AUC, average
    precision, Brier score and precision/recall@K% are computed per cause.

    Assumptions:
        - time_seq is in days
        - horizons_years and the loss CIF times are in years
        - disease token ids in event_seq are >= 2 and map to
          cause_id = token_id - 2

    Args:
        model: Backbone called as
            ``model(event_seq, time_seq, sexes, cont_feats, cate_feats)``.
        head: Module mapping the selected context vectors to logits.
        criterion: Object exposing ``calculate_cifs(logits, taus=...)``,
            which must return a 3-D tensor indexed as
            (sample, cause, horizon).
        dataloader: Yields 5-tuples
            ``(event_seq, time_seq, cont_feats, cate_feats, sexes)``.
        n_disease: Number of causes; used as the number of evaluated causes
            when ``cfg.cause_ids`` is None.
        cfg: Evaluation settings (horizons, offset, top-K percents, causes).
        device: Device the batches are moved to.

    Returns:
        DataFrame with one row per (horizon, cause, topk_percent) and
        columns: cause_id, horizon_tau, topk_percent, n_samples,
        n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score.
        Horizons without any eligible sample get ``n_samples == 0`` and NaN
        metrics. The columns are present even if no rows are produced.

    Raises:
        ValueError: If ``cfg.horizons_years`` or ``cfg.topk_percents`` is
            empty, a top-K percent lies outside (0, 100], the CIF tensor is
            not 3-D, or accumulated labels and predictions disagree in
            shape.
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years = [float(x) for x in cfg.horizons_years]
    if not horizons_years:
        raise ValueError("cfg.horizons_years must be non-empty")

    topk_percents = [float(x) for x in cfg.topk_percents]
    if not topk_percents:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0,100]; got {topk_percents}")

    offset_years = float(cfg.offset_years)
    taus_tensor = torch.tensor(
        horizons_years, device=device, dtype=torch.float32)

    # Cause ids as reported in the output; column k of the accumulated
    # arrays corresponds to reported_cause_ids[k].
    if cfg.cause_ids is None:
        cause_ids = None
        reported_cause_ids = list(range(int(n_disease)))
    else:
        reported_cause_ids = [int(c) for c in cfg.cause_ids]
        cause_ids = torch.tensor(
            reported_cause_ids, dtype=torch.long, device=device)

    # Accumulate per horizon
    y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
    y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]

    for batch in dataloader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)  # (B,L,D)

        # Context index selection (independent of horizon); the keep mask
        # is refined per horizon below.
        keep0, t_ctx, _ = select_context_indices(
            event_seq=event_seq,
            time_seq=time_seq,
            offset_years=offset_years,
            tau_years=0.0,
        )
        if not keep0.any():
            continue

        b = torch.arange(event_seq.size(0), device=device)
        c = h[b, t_ctx]  # (B,D)
        logits = head(c)

        # CIFs for all horizons at once
        cifs_all = criterion.calculate_cifs(logits, taus=taus_tensor)
        if cifs_all.ndim != 3:
            raise ValueError(
                f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), got shape={tuple(cifs_all.shape)}"
            )

        for h_idx, tau_y in enumerate(horizons_years):
            keep, _, _ = select_context_indices(
                event_seq=event_seq,
                time_seq=time_seq,
                offset_years=offset_years,
                tau_years=float(tau_y),
            )
            keep = keep & keep0
            if not keep.any():
                continue

            if cause_ids is None:
                y = multi_hot_ever_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    n_disease=n_disease,
                )
                y = y[keep]
                preds = cifs_all[keep, :, h_idx]
            else:
                y = multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    cause_ids=cause_ids,
                    n_disease=n_disease,
                )
                y = y[keep]
                preds = cifs_all[keep, :, h_idx].index_select(
                    dim=1, index=cause_ids)

            y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
            y_pred_by_h[h_idx].append(
                preds.detach().to(torch.float32).cpu().numpy())

    result_columns = [
        "cause_id",
        "horizon_tau",
        "topk_percent",
        "n_samples",
        "n_positives",
        "auc",
        "auprc",
        "recall_at_K",
        "precision_at_K",
        "brier_score",
    ]
    nan = float("nan")
    rows: List[Dict[str, float | int]] = []

    for h_idx, tau_y in enumerate(horizons_years):
        has_data = bool(y_true_by_h[h_idx])
        if has_data:
            y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
            y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
            if y_true.shape != y_pred.shape:
                raise ValueError(
                    f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}"
                )
            n_samples = int(y_true.shape[0])
        else:
            # No eligible samples for this horizon: emit NaN placeholders.
            y_true = y_pred = None
            n_samples = 0

        for k, cause_id in enumerate(reported_cause_ids):
            if has_data:
                yk = y_true[:, k]
                pk = y_pred[:, k]
                n_pos = int(yk.sum())
                # Threshold-free metrics are computed once per cause and
                # repeated on every top-K row.
                auc = _binary_roc_auc(yk, pk)
                auprc = _average_precision(yk, pk)
                brier = (
                    float(np.mean((yk.astype(float) - pk.astype(float)) ** 2))
                    if n_samples > 0 else nan
                )
            else:
                yk = pk = None
                n_pos = 0
                auc = auprc = brier = nan

            for k_percent in topk_percents:
                if has_data:
                    precision_k, recall_k = _precision_recall_at_k_percent(
                        yk, pk, float(k_percent))
                else:
                    precision_k = recall_k = nan
                rows.append(
                    dict(
                        cause_id=cause_id,
                        horizon_tau=float(tau_y),
                        topk_percent=float(k_percent),
                        n_samples=n_samples,
                        n_positives=n_pos,
                        auc=float(auc),
                        auprc=float(auprc),
                        recall_at_K=float(recall_k),
                        precision_at_K=float(precision_k),
                        brier_score=float(brier),
                    )
                )

    # Explicit columns keep the documented schema even when `rows` is empty
    # (e.g. an empty cause list).
    return pd.DataFrame(rows, columns=result_columns)
|
||||
Reference in New Issue
Block a user