# Source file: DeepHealth/evaluation_age_time_dependent.py
# (Repository web-viewer header: 470 lines, 17 KiB, Python.)

from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
import torch
# Progress bars are optional: if tqdm cannot be imported, fall back to a
# no-op wrapper so the evaluation loop still runs (just without a bar).
try:
    from tqdm import tqdm
except Exception:  # pragma: no cover
    def tqdm(x, **kwargs):
        """Return the iterable ``x`` unchanged, ignoring all tqdm keyword options."""
        return x
from utils import (
multi_hot_ever_within_horizon,
multi_hot_selected_causes_within_horizon,
sample_context_in_fixed_age_bin,
)
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
"""Compute ROC AUC for binary labels with tie-aware ranking.
Returns NaN if y_true has no positives or no negatives.
Uses the MannWhitney U statistic with average ranks for ties.
"""
y_true = np.asarray(y_true).astype(bool)
y_score = np.asarray(y_score).astype(float)
n = y_true.size
if n == 0:
return float("nan")
n_pos = int(y_true.sum())
n_neg = n - n_pos
if n_pos == 0 or n_neg == 0:
return float("nan")
# Rank scores ascending, average ranks for ties.
order = np.argsort(y_score, kind="mergesort")
sorted_scores = y_score[order]
ranks = np.empty(n, dtype=float)
i = 0
# 1-based ranks
while i < n:
j = i + 1
while j < n and sorted_scores[j] == sorted_scores[i]:
j += 1
avg_rank = 0.5 * ((i + 1) + j) # ranks i+1 .. j
ranks[order[i:j]] = avg_rank
i = j
sum_ranks_pos = float(ranks[y_true].sum())
u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
return float(u / (n_pos * n_neg))
def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
"""Average precision (area under PR curve using step-wise interpolation).
Returns NaN if no positives.
"""
y_true = np.asarray(y_true).astype(bool)
y_score = np.asarray(y_score).astype(float)
n = y_true.size
if n == 0:
return float("nan")
n_pos = int(y_true.sum())
if n_pos == 0:
return float("nan")
order = np.argsort(-y_score, kind="mergesort")
y = y_true[order]
tp = np.cumsum(y).astype(float)
fp = np.cumsum(~y).astype(float)
precision = tp / np.maximum(tp + fp, 1.0)
# AP = sum over each positive of precision at that point / n_pos
# (equivalent to ∑ Δrecall * precision)
ap = float(np.sum(precision[y]) / n_pos)
# handle potential tiny numerical overshoots
return float(max(0.0, min(1.0, ap)))
def _precision_recall_at_k_percent(
y_true: np.ndarray,
y_score: np.ndarray,
k_percent: float,
) -> Tuple[float, float]:
"""Precision@K% and Recall@K% for binary labels.
Returns (precision, recall). Returns NaN for recall if no positives.
Returns NaN for precision if k leads to 0 selected.
"""
y_true = np.asarray(y_true).astype(bool)
y_score = np.asarray(y_score).astype(float)
n = y_true.size
if n == 0:
return float("nan"), float("nan")
n_pos = int(y_true.sum())
k = int(math.ceil((float(k_percent) / 100.0) * n))
if k <= 0:
return float("nan"), float("nan")
order = np.argsort(-y_score, kind="mergesort")
top = order[:k]
tp_top = int(y_true[top].sum())
precision = tp_top / k
recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
return float(precision), float(recall)
@dataclass
class EvalAgeConfig:
    """Settings for ``evaluate_time_dependent_age_bins``.

    Only ``horizons_years`` and ``age_bins`` are required; both must be
    non-empty when passed to the evaluator, which validates all fields.
    """

    # Prediction horizons (tau); each is cast to float by the evaluator.
    horizons_years: Sequence[float]
    # (low, high) age bins; the evaluator requires high > low for every bin.
    age_bins: Sequence[Tuple[float, float]]
    # K values for precision/recall@K%; the evaluator requires each in (0, 100].
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    # Number of Monte-Carlo repetitions of the context sampling (must be >= 1).
    n_mc: int = 5
    # Base seed from which the per-(MC, tau, bin, batch) sampling seeds are derived.
    seed: int = 0
    # Optional subset of cause ids to evaluate; None evaluates all causes.
    cause_ids: Optional[Sequence[int]] = None
def _aggregate_over_bins(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
    """Collapse the per-age-bin rows of one group into a single metrics row.

    Only bins with ``n_samples > 0`` contribute. Each metric is averaged over
    the bins where it is defined (non-NaN): unweighted for the macro
    aggregate, weighted by ``n_samples`` otherwise. A metric that is NaN in
    every contributing bin stays NaN.
    """
    metric_cols = ("auc", "auprc", "recall_at_K", "precision_at_K", "brier_score")
    used = group[group["n_samples"] > 0]
    out: Dict[str, float | int] = dict(
        n_bins_used=int(used["age_bin_id"].nunique()),
        n_samples_total=int(used["n_samples"].sum()),
        n_positives_total=int(used["n_positives"].sum()),
    )
    weights = used["n_samples"].to_numpy(dtype=float)
    for col in metric_cols:
        values = used[col].to_numpy(dtype=float)
        defined = ~np.isnan(values)
        if not defined.any():
            out[col] = float("nan")
        elif weighted:
            # Skip NaN bins (e.g. AUC with no positives) so one undefined bin
            # does not turn the whole weighted average into NaN; this matches
            # the NaN-skipping behaviour of the macro average.
            out[col] = float(np.average(values[defined], weights=weights[defined]))
        else:
            out[col] = float(values[defined].mean())
    return pd.Series(out)


@torch.no_grad()
def evaluate_time_dependent_age_bins(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalAgeConfig,
    device: str | torch.device,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Delphi-2M-style age-bin evaluation with strict horizon alignment.

    For every Monte-Carlo repetition, horizon ``tau`` and age bin, each batch is
    passed to ``sample_context_in_fixed_age_bin`` which returns a row mask
    ``keep`` and a per-row context position ``t_ctx``. Per the original design
    notes that helper is expected to (a) restrict to tokens inside the bin,
    (b) enforce follow-up coverage ``t_ctx + tau <= t_end`` and (c) sample one
    token per individual -- NOTE(review): confirm against ``utils``.

    The backbone forward pass depends only on the batch, so it is computed
    once per batch (lazily, on the first (tau, bin) that keeps any row) and
    reused; the context gather, head and CIF computation are done per
    (tau, bin).

    Args:
        model: Backbone called as ``model(event_seq, time_seq, sexes,
            cont_feats, cate_feats)``; its output is indexed ``[row, t_ctx]``.
        head: Module mapping the gathered context vectors to logits.
        criterion: Object exposing ``calculate_cifs(logits, taus=...)``, which
            must return a 2-D tensor for a scalar ``tau``.
        dataloader: Yields ``(event_seq, time_seq, cont_feats, cate_feats,
            sexes)`` tensor batches.
        n_disease: Number of causes; used for labels and, when
            ``cfg.cause_ids`` is None, as the number of evaluated causes.
        cfg: Evaluation settings (horizons, age bins, K%, MC repeats, seed,
            optional cause subset).
        device: Device on which the batches are placed.

    Returns:
        ``(df_by_bin, df_agg)``:
        df_by_bin has one row per (mc_idx, age_bin_id, horizon_tau,
        topk_percent, cause_id) with n_samples, n_positives, auc, auprc,
        recall_at_K, precision_at_K and brier_score (NaN metrics and zero
        counts for empty bins).
        df_agg first aggregates across age bins within each MC repetition
        (``agg_type`` in {"macro", "weighted"}), then reports mean and std of
        each metric across MC repetitions.

    Raises:
        ValueError: If horizons, age bins, topk_percents or the cause set are
            empty or invalid, if ``cfg.n_mc < 1``, if ``calculate_cifs`` does
            not return a 2-D tensor, or if labels and predictions disagree in
            shape.
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    # ---- Validate and normalise configuration -------------------------------
    horizons_years = [float(x) for x in cfg.horizons_years]
    if not horizons_years:
        raise ValueError("cfg.horizons_years must be non-empty")

    age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
    if not age_bins:
        raise ValueError("cfg.age_bins must be non-empty")
    for (a, b) in age_bins:
        if not (b > a):
            raise ValueError(
                f"age_bins must be (low, high) with high>low; got {(a, b)}")

    topk_percents = [float(x) for x in cfg.topk_percents]
    if not topk_percents:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0,100]; got {topk_percents}")

    n_mc = int(cfg.n_mc)
    if n_mc <= 0:
        raise ValueError("cfg.n_mc must be >= 1")

    # Materialise the evaluated cause ids once so that column k of the
    # prediction matrix always maps to reported_cause_ids[k].
    if cfg.cause_ids is None:
        cause_ids = None
        reported_cause_ids = list(range(int(n_disease)))
    else:
        reported_cause_ids = [int(c) for c in cfg.cause_ids]
        cause_ids = torch.tensor(
            reported_cause_ids, dtype=torch.long, device=device)
    n_causes_eval = len(reported_cause_ids)
    if n_causes_eval == 0:
        # Without this the result table would be empty and the groupby below
        # would fail with an opaque KeyError.
        raise ValueError(
            "No causes to evaluate: n_disease must be > 0 and cfg.cause_ids, "
            "when provided, must be non-empty")

    # ---- Collect labels and predictions -------------------------------------
    # Storage indexed [mc][tau][bin] -> list of per-batch arrays.
    y_true: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in age_bins] for _ in horizons_years]
        for _ in range(n_mc)
    ]
    y_pred: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in age_bins] for _ in horizons_years]
        for _ in range(n_mc)
    ]

    for mc_idx in range(n_mc):
        # tqdm over batches; include MC idx in description.
        for batch_idx, batch in enumerate(
            tqdm(dataloader,
                 desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
        ):
            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
            event_seq = event_seq.to(device)
            time_seq = time_seq.to(device)
            cont_feats = cont_feats.to(device)
            cate_feats = cate_feats.to(device)
            sexes = sexes.to(device)

            batch_rows = torch.arange(int(event_seq.size(0)), device=device)
            # Backbone output for this batch; filled in on first use because
            # it does not depend on tau or the age bin.
            hidden = None

            for tau_idx, tau_y in enumerate(horizons_years):
                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                    # Diversify RNG stream across MC/tau/bin/batch to reduce correlation.
                    # NOTE(review): the strides can collide (e.g. batch_idx >= 10
                    # overlaps the next bin); kept as-is so existing seeds
                    # reproduce the same samples.
                    seed = (
                        int(cfg.seed)
                        + (100_000 * int(mc_idx))
                        + (1_000 * int(tau_idx))
                        + (10 * int(bin_idx))
                        + int(batch_idx)
                    )
                    keep, t_ctx = sample_context_in_fixed_age_bin(
                        event_seq=event_seq,
                        time_seq=time_seq,
                        tau_years=float(tau_y),
                        age_bin=(float(a_lo), float(a_hi)),
                        seed=seed,
                    )
                    if not keep.any():
                        continue

                    if hidden is None:
                        hidden = model(event_seq, time_seq, sexes,
                                       cont_feats, cate_feats)  # (B,L,D)

                    # Bin-specific context vector, logits and CIFs.
                    context = hidden[batch_rows, t_ctx]
                    logits = head(context)
                    cifs = criterion.calculate_cifs(
                        logits, taus=torch.tensor(float(tau_y), device=device)
                    )
                    if cifs.ndim != 2:
                        raise ValueError(
                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
                            f"got shape={tuple(cifs.shape)}"
                        )

                    if cause_ids is None:
                        y = multi_hot_ever_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            n_disease=n_disease,
                        )
                        preds = cifs
                    else:
                        y = multi_hot_selected_causes_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            cause_ids=cause_ids,
                            n_disease=n_disease,
                        )
                        preds = cifs.index_select(dim=1, index=cause_ids)

                    y_true[mc_idx][tau_idx][bin_idx].append(
                        y[keep].detach().to(torch.bool).cpu().numpy()
                    )
                    y_pred[mc_idx][tau_idx][bin_idx].append(
                        preds[keep].detach().to(torch.float32).cpu().numpy()
                    )

    # ---- Per-(mc, tau, bin, cause, K) metrics -------------------------------
    nan = float("nan")
    rows_by_bin: List[Dict[str, float | int]] = []
    for mc_idx in range(n_mc):
        for h_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                true_chunks = y_true[mc_idx][h_idx][bin_idx]
                pred_chunks = y_pred[mc_idx][h_idx][bin_idx]

                if not true_chunks:
                    # No samples in this bin for this (mc, tau): emit
                    # placeholder rows so the table stays rectangular.
                    for cause_id in reported_cause_ids:
                        for k_percent in topk_percents:
                            rows_by_bin.append(
                                dict(
                                    mc_idx=mc_idx,
                                    age_bin_id=bin_idx,
                                    age_bin_low=float(a_lo),
                                    age_bin_high=float(a_hi),
                                    horizon_tau=float(tau_y),
                                    topk_percent=float(k_percent),
                                    cause_id=cause_id,
                                    n_samples=0,
                                    n_positives=0,
                                    auc=nan,
                                    auprc=nan,
                                    recall_at_K=nan,
                                    precision_at_K=nan,
                                    brier_score=nan,
                                )
                            )
                    continue

                yb = np.concatenate(true_chunks, axis=0)
                pb = np.concatenate(pred_chunks, axis=0)
                if yb.shape != pb.shape:
                    raise ValueError(
                        f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
                    )
                n_samples = int(yb.shape[0])

                for cause_k, cause_id in enumerate(reported_cause_ids):
                    yk = yb[:, cause_k]
                    pk = pb[:, cause_k]
                    n_pos = int(yk.sum())
                    # Threshold-free metrics are shared by all K% rows.
                    auc = _binary_roc_auc(yk, pk)
                    auprc = _average_precision(yk, pk)
                    brier = float(np.mean(
                        (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else nan
                    for k_percent in topk_percents:
                        precision_k, recall_k = _precision_recall_at_k_percent(
                            yk, pk, float(k_percent))
                        rows_by_bin.append(
                            dict(
                                mc_idx=mc_idx,
                                age_bin_id=bin_idx,
                                age_bin_low=float(a_lo),
                                age_bin_high=float(a_hi),
                                horizon_tau=float(tau_y),
                                topk_percent=float(k_percent),
                                cause_id=cause_id,
                                n_samples=n_samples,
                                n_positives=n_pos,
                                auc=float(auc),
                                auprc=float(auprc),
                                recall_at_K=float(recall_k),
                                precision_at_K=float(precision_k),
                                brier_score=float(brier),
                            )
                        )
    df_by_bin = pd.DataFrame(rows_by_bin)

    # ---- Aggregate across age bins within each MC repetition ----------------
    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
    per_mc_frames = []
    for agg_type, use_weights in (("macro", False), ("weighted", True)):
        frame = (
            df_by_bin.groupby(group_keys)
            .apply(lambda g, w=use_weights: _aggregate_over_bins(g, weighted=w))
            .reset_index()
        )
        frame["agg_type"] = agg_type
        per_mc_frames.append(frame)
    df_mc_binagg = pd.concat(per_mc_frames, ignore_index=True)

    # ---- Then average over MC repetitions -----------------------------------
    df_agg = (
        df_mc_binagg.groupby(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
        )
        .agg(
            n_mc=("mc_idx", "nunique"),
            n_bins_used_mean=("n_bins_used", "mean"),
            n_samples_total_mean=("n_samples_total", "mean"),
            n_positives_total_mean=("n_positives_total", "mean"),
            auc_mean=("auc", "mean"),
            auc_std=("auc", "std"),
            auprc_mean=("auprc", "mean"),
            auprc_std=("auprc", "std"),
            recall_at_K_mean=("recall_at_K", "mean"),
            recall_at_K_std=("recall_at_K", "std"),
            precision_at_K_mean=("precision_at_K", "mean"),
            precision_at_K_std=("precision_at_K", "std"),
            brier_score_mean=("brier_score", "mean"),
            brier_score_std=("brier_score", "std"),
        )
        .sort_values(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
            ignore_index=True,
        )
    )
    return df_by_bin, df_agg