DeepHealth/evaluation_time_dependent.py
Jiarui Li 34d8d8ce9d Add evaluation and utility functions for time-dependent metrics
- Introduced `evaluate.py` for time-dependent evaluation of models, including data loading and model inference.
- Added `evaluation_time_dependent.py` to compute evaluation metrics such as ROC AUC, average precision, Brier score, and precision/recall at specified top-K% cutoffs.
- Implemented CIF calculation methods in `losses.py` for different loss types, including exponential and piecewise exponential models.
- Created utility functions in `utils.py` for context selection and multi-hot encoding of events within specified horizons.
2026-01-16 14:55:09 +08:00

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

from utils import (
    DAYS_PER_YEAR,
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    select_context_indices,
)


def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Returns NaN if y_true has no positives or no negatives.
    Uses the Mann-Whitney U statistic with average ranks for ties.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # Rank scores ascending, averaging ranks for ties.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    ranks = np.empty(n, dtype=float)
    i = 0
    # 1-based ranks
    while i < n:
        j = i + 1
        while j < n and sorted_scores[j] == sorted_scores[i]:
            j += 1
        avg_rank = 0.5 * ((i + 1) + j)  # average of ranks i+1 .. j
        ranks[order[i:j]] = avg_rank
        i = j
    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))
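

# Worked example for _binary_roc_auc (illustrative only, not part of the
# original file): for labels [0, 0, 1, 1] and scores [0.1, 0.4, 0.35, 0.8],
# three of the four positive/negative score pairs are ordered correctly, so
#
#     _binary_roc_auc(np.array([0, 0, 1, 1]), np.array([0.1, 0.4, 0.35, 0.8]))
#
# returns 3 / 4 = 0.75.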


def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under PR curve using step-wise interpolation).

    Returns NaN if no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    y = y_true[order]
    tp = np.cumsum(y).astype(float)
    fp = np.cumsum(~y).astype(float)
    precision = tp / np.maximum(tp + fp, 1.0)
    recall = tp / n_pos
    # AP = sum over each positive of precision at that point / n_pos
    # (equivalent to ∑ Δrecall * precision)
    ap = float(np.sum(precision[y]) / n_pos)
    # Clamp to [0, 1] to absorb potential tiny numerical overshoots.
    return float(max(0.0, min(1.0, ap)))
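

# Worked example for _average_precision (illustrative only): with the same
# labels [0, 0, 1, 1] and scores [0.1, 0.4, 0.35, 0.8], the positives land at
# ranks 1 and 3 of the descending-score ordering, with precisions 1/1 and 2/3,
# so
#
#     _average_precision(np.array([0, 0, 1, 1]), np.array([0.1, 0.4, 0.35, 0.8]))
#
# returns (1.0 + 2/3) / 2 ≈ 0.8333.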


def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    Returns (precision, recall). Recall is NaN if there are no positives;
    precision is NaN if k rounds to 0 selected samples.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")
    n_pos = int(y_true.sum())
    k = int(math.ceil((float(k_percent) / 100.0) * n))
    if k <= 0:
        return float("nan"), float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    top = order[:k]
    tp_top = int(y_true[top].sum())
    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)
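

# Worked example for _precision_recall_at_k_percent (illustrative only): with
# k_percent=50.0 and n=4, k = ceil(0.5 * 4) = 2. If the top-2 scores contain
# exactly one of the two positives (as in the arrays above), the call
#
#     _precision_recall_at_k_percent(
#         np.array([0, 0, 1, 1]), np.array([0.1, 0.4, 0.35, 0.8]), 50.0)
#
# returns (precision, recall) = (0.5, 0.5).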


@dataclass
class EvalConfig:
    horizons_years: Sequence[float]
    offset_years: float = 0.0
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    cause_ids: Optional[Sequence[int]] = None
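

# Example configuration (hypothetical values, for illustration only): evaluate
# 1-, 5-, and 10-year horizons for two selected causes with a 0.5-year
# context-selection offset; cause ids 42 and 137 are placeholders:
#
#     cfg = EvalConfig(
#         horizons_years=[1.0, 5.0, 10.0],
#         offset_years=0.5,
#         topk_percents=(1.0, 5.0, 10.0),
#         cause_ids=[42, 137],
#     )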


@torch.no_grad()
def evaluate_time_dependent(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalConfig,
    device: str | torch.device,
) -> pd.DataFrame:
    """Evaluate time-dependent metrics per cause and per horizon.

    Assumptions:
        - time_seq is in days
        - horizons_years and the loss CIF times are in years
        - disease token ids in event_seq are >= 2 and map to cause_id = token_id - 2

    Returns:
        DataFrame with columns:
            cause_id, horizon_tau, topk_percent, n_samples, n_positives, auc,
            auprc, recall_at_K, precision_at_K, brier_score
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")
    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0, 100]; got {topk_percents}")

    taus_tensor = torch.tensor(
        horizons_years, device=device, dtype=torch.float32)
    if cfg.cause_ids is None:
        cause_ids = None
        n_causes_eval = int(n_disease)
    else:
        cause_ids = torch.tensor(
            list(cfg.cause_ids), dtype=torch.long, device=device)
        n_causes_eval = int(cause_ids.numel())

    # Accumulate labels and predictions per horizon.
    y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
    y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]

    for batch in dataloader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)  # (B, L, D)

        # Context index selection (independent of horizon); the keep mask is
        # refined per horizon below.
        keep0, t_ctx, _ = select_context_indices(
            event_seq=event_seq,
            time_seq=time_seq,
            offset_years=float(cfg.offset_years),
            tau_years=0.0,
        )
        if not keep0.any():
            continue

        b = torch.arange(event_seq.size(0), device=device)
        c = h[b, t_ctx]  # (B, D)
        logits = head(c)

        # CIFs for all horizons at once; shape (B, K, T) is enforced below.
        cifs_all = criterion.calculate_cifs(
            logits, taus=taus_tensor)  # (B, K, T)
        if cifs_all.ndim != 3:
            raise ValueError(
                f"criterion.calculate_cifs must return (B,K,T) when taus is "
                f"(T,), got shape={tuple(cifs_all.shape)}"
            )

        for h_idx, tau_y in enumerate(horizons_years):
            keep, _, _ = select_context_indices(
                event_seq=event_seq,
                time_seq=time_seq,
                offset_years=float(cfg.offset_years),
                tau_years=float(tau_y),
            )
            keep = keep & keep0
            if not keep.any():
                continue
            if cause_ids is None:
                y = multi_hot_ever_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    n_disease=n_disease,
                )
                y = y[keep]
                preds = cifs_all[keep, :, h_idx]
            else:
                y = multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    cause_ids=cause_ids,
                    n_disease=n_disease,
                )
                y = y[keep]
                preds = cifs_all[keep, :, h_idx].index_select(
                    dim=1, index=cause_ids)
            y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
            y_pred_by_h[h_idx].append(
                preds.detach().to(torch.float32).cpu().numpy())

    rows: List[Dict[str, float | int]] = []
    for h_idx, tau_y in enumerate(horizons_years):
        if len(y_true_by_h[h_idx]) == 0:
            # No eligible samples for this horizon: emit NaN rows.
            for k in range(n_causes_eval):
                cause_id = int(k) if cause_ids is None else int(
                    cfg.cause_ids[k])
                for k_percent in topk_percents:
                    rows.append(
                        dict(
                            cause_id=cause_id,
                            horizon_tau=float(tau_y),
                            topk_percent=float(k_percent),
                            n_samples=0,
                            n_positives=0,
                            auc=float("nan"),
                            auprc=float("nan"),
                            recall_at_K=float("nan"),
                            precision_at_K=float("nan"),
                            brier_score=float("nan"),
                        )
                    )
            continue
        y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
        y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"Shape mismatch at tau={tau_y}: y_true{tuple(y_true.shape)} "
                f"vs y_pred{tuple(y_pred.shape)}"
            )
        n_samples = int(y_true.shape[0])
        for k in range(n_causes_eval):
            yk = y_true[:, k]
            pk = y_pred[:, k]
            n_pos = int(yk.sum())
            auc = _binary_roc_auc(yk, pk)
            auprc = _average_precision(yk, pk)
            brier = (
                float(np.mean((yk.astype(float) - pk.astype(float)) ** 2))
                if n_samples > 0
                else float("nan")
            )
            cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
            for k_percent in topk_percents:
                precision_k, recall_k = _precision_recall_at_k_percent(
                    yk, pk, float(k_percent))
                rows.append(
                    dict(
                        cause_id=cause_id,
                        horizon_tau=float(tau_y),
                        topk_percent=float(k_percent),
                        n_samples=n_samples,
                        n_positives=n_pos,
                        auc=float(auc),
                        auprc=float(auprc),
                        recall_at_K=float(recall_k),
                        precision_at_K=float(precision_k),
                        brier_score=float(brier),
                    )
                )
    return pd.DataFrame(rows)
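

# Usage sketch (all names below are placeholders; assumes `model`, `head`,
# `criterion`, and `val_loader` follow the batch layout unpacked above, and
# that `criterion.calculate_cifs` returns (B, K, T) CIFs as required here):
#
#     cfg = EvalConfig(horizons_years=[1.0, 5.0, 10.0])
#     df = evaluate_time_dependent(
#         model, head, criterion, val_loader,
#         n_disease=N_DISEASE, cfg=cfg, device="cuda",
#     )
#     df.to_csv("time_dependent_metrics.csv", index=False)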