# DeepHealth/evaluation_time_dependent.py


from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

from utils import (
    DAYS_PER_YEAR,
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    select_context_indices,
)


def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Returns NaN if y_true has no positives or no negatives.
    Uses the Mann-Whitney U statistic with average ranks for ties.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # Rank scores ascending, averaging ranks for ties.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    ranks = np.empty(n, dtype=float)
    i = 0
    # Assign 1-based ranks; each tie group gets its mean rank.
    while i < n:
        j = i + 1
        while j < n and sorted_scores[j] == sorted_scores[i]:
            j += 1
        avg_rank = 0.5 * ((i + 1) + j)  # mean of ranks i+1 .. j
        ranks[order[i:j]] = avg_rank
        i = j
    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))
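
# Worked example for _binary_roc_auc (sanity check): with
#   y_true  = [0, 0, 1, 1]
#   y_score = [0.1, 0.4, 0.35, 0.8]
# the ascending ranks are [1, 3, 2, 4], so the positive ranks sum to
# 2 + 4 = 6, U = 6 - 2*3/2 = 3, and AUC = 3 / (2 * 2) = 0.75.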


def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under the PR curve, step-wise interpolation).

    Returns NaN if there are no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    y = y_true[order]
    tp = np.cumsum(y).astype(float)
    fp = np.cumsum(~y).astype(float)
    precision = tp / np.maximum(tp + fp, 1.0)
    recall = tp / n_pos
    # AP = mean precision over the positive positions,
    # equivalent to sum(delta_recall * precision).
    ap = float(np.sum(precision[y]) / n_pos)
    # Clamp potential tiny numerical overshoots into [0, 1].
    return float(max(0.0, min(1.0, ap)))
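
# Worked example for _average_precision (sanity check): with
#   y_true  = [1, 0, 1]
#   y_score = [0.9, 0.8, 0.7]
# the positives are hit at ranks 1 and 3 with precisions 1/1 and 2/3,
# so AP = (1.0 + 2/3) / 2 = 5/6 ~ 0.833.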


def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    Returns (precision, recall). Recall is NaN if there are no positives;
    precision is NaN if the chosen k selects zero samples.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")
    n_pos = int(y_true.sum())
    k = int(math.ceil((float(k_percent) / 100.0) * n))
    if k <= 0:
        return float("nan"), float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    top = order[:k]
    tp_top = int(y_true[top].sum())
    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)
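
# Worked example for _precision_recall_at_k_percent (sanity check): with
# n = 200 samples and k_percent = 5.0, k = ceil(0.05 * 200) = 10 samples are
# flagged. If 4 of those 10 are positive and there are 8 positives in total,
# precision@5% = 4/10 = 0.4 and recall@5% = 4/8 = 0.5.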


@dataclass
class EvalConfig:
    horizons_years: Sequence[float]
    offset_years: float = 0.0
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    cause_ids: Optional[Sequence[int]] = None
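
# Illustrative usage (the horizon values are examples, not fixed by this
# repo): evaluate 1-, 5-, and 10-year horizons over all causes with the
# default top-K percentages.
#
#   cfg = EvalConfig(horizons_years=(1.0, 5.0, 10.0))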


@torch.no_grad()
def evaluate_time_dependent(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalConfig,
    device: str | torch.device,
) -> pd.DataFrame:
    """Evaluate time-dependent metrics per cause and per horizon.

    Assumptions:
        - time_seq is in days
        - horizons_years and the loss CIF times are in years
        - disease token ids in event_seq are >= 2 and map to
          cause_id = token_id - 2

    Returns:
        DataFrame with columns:
            cause_id, horizon_tau, topk_percent, n_samples, n_positives,
            auc, auprc, recall_at_K, precision_at_K, brier_score
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")
    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0, 100]; got {topk_percents}")

    taus_tensor = torch.tensor(horizons_years, device=device, dtype=torch.float32)
    if cfg.cause_ids is None:
        cause_ids = None
        n_causes_eval = int(n_disease)
    else:
        cause_ids = torch.tensor(list(cfg.cause_ids), dtype=torch.long, device=device)
        n_causes_eval = int(cause_ids.numel())

    # Accumulate per horizon.
    y_true_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
    y_pred_by_h: List[List[np.ndarray]] = [[] for _ in horizons_years]
    for batch in dataloader:
        event_seq, time_seq, cont_feats, cate_feats, sexes = batch
        event_seq = event_seq.to(device)
        time_seq = time_seq.to(device)
        cont_feats = cont_feats.to(device)
        cate_feats = cate_feats.to(device)
        sexes = sexes.to(device)

        h = model(event_seq, time_seq, sexes, cont_feats, cate_feats)  # (B, L, D)

        # Select a single fixed context per sample for this batch.
        # Horizon-specific eligibility is derived from this context
        # (do not re-select per horizon).
        keep0, t_ctx, t_ctx_time = select_context_indices(
            event_seq=event_seq,
            time_seq=time_seq,
            offset_years=float(cfg.offset_years),
            tau_years=0.0,
        )
        if not keep0.any():
            continue
        b = torch.arange(event_seq.size(0), device=device)
        c = h[b, t_ctx]  # (B, D)
        logits = head(c)

        # CIFs for all horizons at once.
        cifs_all = criterion.calculate_cifs(logits, taus=taus_tensor)  # (B, K, T) or (B, K)
        if cifs_all.ndim != 3:
            raise ValueError(
                f"criterion.calculate_cifs must return (B,K,T) when taus is (T,), "
                f"got shape={tuple(cifs_all.shape)}"
            )

        # Follow-up end time per sample = time at last valid token.
        valid = event_seq != 0
        lengths = valid.sum(dim=1)
        last_idx = torch.clamp(lengths - 1, min=0)
        followup_end_time = time_seq[b, last_idx]

        for h_idx, tau_y in enumerate(horizons_years):
            # Horizon-specific eligibility without reselecting context:
            # keep_tau = keep0 & (followup_end_time >= t_ctx_time + tau)
            keep_tau = keep0 & (
                followup_end_time >= (t_ctx_time + (float(tau_y) * DAYS_PER_YEAR))
            )
            if not keep_tau.any():
                continue
            if cause_ids is None:
                y = multi_hot_ever_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    n_disease=n_disease,
                )
                y = y[keep_tau]
                preds = cifs_all[keep_tau, :, h_idx]
            else:
                y = multi_hot_selected_causes_within_horizon(
                    event_seq=event_seq,
                    time_seq=time_seq,
                    t_ctx=t_ctx,
                    tau_years=float(tau_y),
                    cause_ids=cause_ids,
                    n_disease=n_disease,
                )
                y = y[keep_tau]
                preds = cifs_all[keep_tau, :, h_idx].index_select(dim=1, index=cause_ids)
            y_true_by_h[h_idx].append(y.detach().to(torch.bool).cpu().numpy())
            y_pred_by_h[h_idx].append(preds.detach().to(torch.float32).cpu().numpy())
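
    # Each stored entry in y_true_by_h[h_idx] / y_pred_by_h[h_idx] is one
    # batch's (n_kept, n_causes_eval) array; concatenating along axis 0
    # below pools every eligible sample for the horizon.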
    rows: List[Dict[str, float | int]] = []
    for h_idx, tau_y in enumerate(horizons_years):
        if len(y_true_by_h[h_idx]) == 0:
            # No eligible samples for this horizon.
            for k in range(n_causes_eval):
                cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
                for k_percent in topk_percents:
                    rows.append(
                        dict(
                            cause_id=cause_id,
                            horizon_tau=float(tau_y),
                            topk_percent=float(k_percent),
                            n_samples=0,
                            n_positives=0,
                            auc=float("nan"),
                            auprc=float("nan"),
                            recall_at_K=float("nan"),
                            precision_at_K=float("nan"),
                            brier_score=float("nan"),
                        )
                    )
            continue
        y_true = np.concatenate(y_true_by_h[h_idx], axis=0)
        y_pred = np.concatenate(y_pred_by_h[h_idx], axis=0)
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"Shape mismatch at tau={tau_y}: "
                f"y_true{tuple(y_true.shape)} vs y_pred{tuple(y_pred.shape)}"
            )
        n_samples = int(y_true.shape[0])
        for k in range(n_causes_eval):
            yk = y_true[:, k]
            pk = y_pred[:, k]
            n_pos = int(yk.sum())
            auc = _binary_roc_auc(yk, pk)
            auprc = _average_precision(yk, pk)
            brier = (
                float(np.mean((yk.astype(float) - pk.astype(float)) ** 2))
                if n_samples > 0
                else float("nan")
            )
            cause_id = int(k) if cause_ids is None else int(cfg.cause_ids[k])
            for k_percent in topk_percents:
                precision_k, recall_k = _precision_recall_at_k_percent(yk, pk, float(k_percent))
                rows.append(
                    dict(
                        cause_id=cause_id,
                        horizon_tau=float(tau_y),
                        topk_percent=float(k_percent),
                        n_samples=n_samples,
                        n_positives=n_pos,
                        auc=float(auc),
                        auprc=float(auprc),
                        recall_at_K=float(recall_k),
                        precision_at_K=float(precision_k),
                        brier_score=float(brier),
                    )
                )

    return pd.DataFrame(rows)
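

if __name__ == "__main__":
    # Minimal smoke test of the metric helpers on toy arrays (a sanity
    # check only; evaluate_time_dependent itself requires a trained model,
    # a criterion exposing calculate_cifs, and a dataloader). Expected
    # values follow the worked examples above.
    y = np.array([0, 0, 1, 1], dtype=bool)
    s = np.array([0.1, 0.4, 0.35, 0.8])
    assert abs(_binary_roc_auc(y, s) - 0.75) < 1e-9
    assert abs(
        _average_precision(np.array([1, 0, 1], dtype=bool),
                           np.array([0.9, 0.8, 0.7])) - 5.0 / 6.0
    ) < 1e-9
    # Top 50% of 4 samples keeps ceil(2) = 2; one of the two is a positive.
    p_at_k, r_at_k = _precision_recall_at_k_percent(y, s, 50.0)
    assert p_at_k == 0.5 and r_at_k == 0.5
    print("metric helper smoke tests passed")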