Files
DeepHealth/evaluation_age_time_dependent.py

466 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from utils import (
multi_hot_ever_within_horizon,
multi_hot_selected_causes_within_horizon,
sample_context_in_fixed_age_bin,
)
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    The AUC is the normalised Mann-Whitney U statistic: the probability that a
    randomly chosen positive scores higher than a randomly chosen negative,
    with tied scores counted as one half (via average ranks).

    Args:
        y_true: 1-D labels; any truthy value counts as a positive.
        y_score: 1-D scores, same length as ``y_true``.

    Returns:
        The AUC, or NaN if ``y_true`` is empty or has no positives or no
        negatives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Rank scores ascending (stable sort), then give each run of equal scores
    # the average of the 1-based ranks it spans.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    # Exact element-wise `!=` (rather than np.diff) keeps +/-inf ties together
    # and NaNs apart, exactly like a scalar `==` comparison would.
    run_starts = np.flatnonzero(
        np.r_[True, sorted_scores[1:] != sorted_scores[:-1]]
    )
    run_ends = np.r_[run_starts[1:], n]  # exclusive end of each run
    # Sorted positions [s, e) hold 1-based ranks s+1 .. e; their mean is
    # (s + 1 + e) / 2.
    avg_rank = 0.5 * (run_starts + 1 + run_ends)
    ranks = np.empty(n, dtype=float)
    ranks[order] = np.repeat(avg_rank, run_ends - run_starts)

    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))
def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under the PR curve, step-wise interpolation).

    Computes ``AP = sum_t (R_t - R_{t-1}) * P_t`` over the distinct score
    thresholds ``t`` in decreasing order. Here ``P_t`` / ``R_t`` are the
    precision / recall when every sample scoring >= ``t`` is predicted
    positive.

    Samples with tied scores enter at the same threshold. The result therefore
    does not depend on the input order of tied samples. This is consistent
    with the tie-aware ROC AUC above.

    Returns NaN if ``y_true`` is empty or has no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")

    order = np.argsort(-y_score, kind="mergesort")
    sorted_scores = y_score[order]
    y = y_true[order]
    tp = np.cumsum(y).astype(float)
    fp = np.cumsum(~y).astype(float)
    # The last index of each run of tied scores is one operating point
    # (threshold) on the PR curve.
    run_ends = np.r_[
        np.flatnonzero(sorted_scores[1:] != sorted_scores[:-1]), n - 1
    ]
    tp_t = tp[run_ends]
    # tp_t + fp_t == run_end + 1 >= 1, so no division by zero.
    precision = tp_t / (tp_t + fp[run_ends])
    recall = tp_t / n_pos
    delta_recall = np.diff(recall, prepend=0.0)
    ap = float(np.sum(delta_recall * precision))
    # Guard against tiny numerical overshoots.
    return float(max(0.0, min(1.0, ap)))
def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    The top ``k = ceil(k_percent * n / 100)`` samples by score (capped at
    ``n``) are selected. Ties are broken stably by input order.

    Returns:
        ``(precision, recall)``. Both are NaN if the input is empty or ``k``
        is 0 (``k_percent <= 0``). Recall is NaN if there are no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)
    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")
    n_pos = int(y_true.sum())
    # Multiply before dividing: (7 / 100) * 100 == 7.000000000000001 in binary
    # floating point, which ceil() would turn into 8. k_percent * n is exact
    # for integral percentages, so the division then lands exactly on integers.
    # Capping at n keeps precision's denominator equal to the number selected.
    k = min(int(math.ceil(float(k_percent) * n / 100.0)), n)
    if k <= 0:
        return float("nan"), float("nan")
    order = np.argsort(-y_score, kind="mergesort")
    tp_top = int(y_true[order[:k]].sum())
    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)
@dataclass()
class EvalAgeConfig:
    """Settings consumed by ``evaluate_time_dependent_age_bins``.

    ``horizons_years`` and ``age_bins`` are required. All other fields have
    defaults.
    """

    # Prediction horizons (in years) at which outcomes and CIFs are evaluated.
    horizons_years: Sequence[float]
    # (low, high) age edges; one context token per individual is sampled per bin.
    age_bins: Sequence[Tuple[float, float]]
    # Percentages K used for precision@K% and recall@K%.
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    # Number of Monte Carlo repetitions of the context sampling.
    n_mc: int = 5
    # Base seed, combined by the evaluator with the MC/horizon/bin/batch indices.
    seed: int = 0
    # Optional subset of cause indices to score; None means all n_disease causes.
    cause_ids: Optional[Sequence[int]] = None
@torch.no_grad()
def evaluate_time_dependent_age_bins(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalAgeConfig,
    device: str | torch.device,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Delphi-2M-style age-bin evaluation with strict horizon alignment.

    Semantics (strict): for each (MC repetition, horizon tau, age bin) we
    independently:
      - sample exactly one context token per individual within the bin, with
        follow-up coverage t_ctx + tau <= t_end. This is delegated to
        ``sample_context_in_fixed_age_bin``; per its contract it builds the
        eligible set and enforces coverage.
      - take the representation at that token and predict CIFs at tau.

    The backbone ``model`` runs at most once per batch. Its inputs do not
    depend on (tau, bin), so its output is cached, and only the context
    positions, ``head`` logits and CIFs are recomputed per (tau, bin). This
    assumes ``model`` is deterministic in eval mode.

    Args:
        model: Backbone called as ``model(event_seq, time_seq, sexes,
            cont_feats, cate_feats)``. Its output is indexed as
            ``h[batch, position]``.
        head: Maps the selected context representations to logits.
        criterion: Exposes ``calculate_cifs(logits, taus=...)``. It must
            return a 2-D tensor for a scalar tau.
        dataloader: Yields ``(event_seq, time_seq, cont_feats, cate_feats,
            sexes)`` batches.
        n_disease: Number of causes. All of them are scored when
            ``cfg.cause_ids`` is None.
        cfg: Evaluation configuration.
        device: Device on which batches are evaluated.

    Returns:
        df_by_bin: One row per (mc_idx, age_bin_id, horizon_tau,
            topk_percent, cause_id), with n_samples, n_positives, auc, auprc,
            recall_at_K, precision_at_K and brier_score. Bins with no sampled
            individual get n_samples=0 and NaN metrics.
        df_agg: Metrics aggregated across age bins, then mean/std across MC
            repetitions. ``agg_type`` is "macro" (unweighted mean over bins)
            or "weighted" (mean weighted by bin n_samples). In both, bins
            where a metric is NaN are skipped for that metric.

    Raises:
        ValueError: On invalid ``cfg`` / ``n_disease`` values, or an
            unexpected CIF or label shape.
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years, age_bins, topk_percents, cause_id_list = (
        _validate_age_eval_inputs(cfg, n_disease)
    )
    if cause_id_list is None:
        cause_ids = None
        cause_labels = list(range(int(n_disease)))
    else:
        cause_ids = torch.tensor(cause_id_list, dtype=torch.long, device=device)
        cause_labels = cause_id_list

    y_true, y_pred = _collect_age_bin_predictions(
        model=model,
        head=head,
        criterion=criterion,
        dataloader=dataloader,
        n_disease=n_disease,
        cfg=cfg,
        horizons_years=horizons_years,
        age_bins=age_bins,
        cause_ids=cause_ids,
        device=device,
    )
    df_by_bin = _age_bin_metrics_frame(
        y_true, y_pred, horizons_years, age_bins, topk_percents, cause_labels
    )
    df_agg = _aggregate_age_bin_metrics(df_by_bin)
    return df_by_bin, df_agg


# Per-bin metric columns that are aggregated across age bins and MC runs.
_AGE_BIN_METRIC_COLUMNS = (
    "auc",
    "auprc",
    "recall_at_K",
    "precision_at_K",
    "brier_score",
)


def _validate_age_eval_inputs(
    cfg: EvalAgeConfig, n_disease: int
) -> Tuple[List[float], List[Tuple[float, float]], List[float], Optional[List[int]]]:
    """Normalise ``cfg`` to plain Python values; raise ValueError if invalid.

    Returns ``(horizons_years, age_bins, topk_percents, cause_ids)``, where
    ``cause_ids`` is None when all ``n_disease`` causes are to be scored.
    """
    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")
    age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
    if len(age_bins) == 0:
        raise ValueError("cfg.age_bins must be non-empty")
    for (a, b) in age_bins:
        if not (b > a):
            raise ValueError(
                f"age_bins must be (low, high) with high>low; got {(a, b)}")
    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    # Negated range check so that NaN is rejected too.
    if any(not (0.0 < p <= 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0,100]; got {topk_percents}")
    if int(cfg.n_mc) <= 0:
        raise ValueError("cfg.n_mc must be >= 1")
    if cfg.cause_ids is None:
        # An empty result table would otherwise fail later in the groupby.
        if int(n_disease) <= 0:
            raise ValueError(
                "n_disease must be >= 1 when cfg.cause_ids is None")
        return horizons_years, age_bins, topk_percents, None
    # Materialise once so cfg.cause_ids is never re-indexed later.
    cause_id_list = [int(c) for c in cfg.cause_ids]
    if len(cause_id_list) == 0:
        raise ValueError("cfg.cause_ids must be non-empty when provided")
    return horizons_years, age_bins, topk_percents, cause_id_list


def _collect_age_bin_predictions(
    *,
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalAgeConfig,
    horizons_years: List[float],
    age_bins: List[Tuple[float, float]],
    cause_ids: Optional[torch.Tensor],
    device: torch.device,
) -> Tuple[List[List[List[List[np.ndarray]]]], List[List[List[List[np.ndarray]]]]]:
    """Sample contexts and gather labels and predictions per (mc, tau, bin).

    Returns two nested lists indexed ``[mc_idx][tau_idx][bin_idx]``. Each
    holds one array per batch that had at least one kept individual: labels
    as bool, predictions as float32.
    """
    n_mc = int(cfg.n_mc)
    y_true: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in age_bins] for _ in horizons_years] for _ in range(n_mc)
    ]
    y_pred: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in age_bins] for _ in horizons_years] for _ in range(n_mc)
    ]

    for mc_idx in range(n_mc):
        batches = tqdm(
            dataloader, desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch"
        )
        for batch_idx, batch in enumerate(batches):
            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
            event_seq = event_seq.to(device)
            time_seq = time_seq.to(device)
            cont_feats = cont_feats.to(device)
            cate_feats = cate_feats.to(device)
            sexes = sexes.to(device)
            rows = torch.arange(int(event_seq.size(0)), device=device)

            # Backbone output for this batch. It does not depend on (tau, bin),
            # so it is computed lazily on first use and reused afterwards.
            h = None
            for tau_idx, tau_y in enumerate(horizons_years):
                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                    # Vary the RNG seed across MC/tau/bin/batch. NOTE: this
                    # linear mix is not collision-free (e.g. bin 1/batch 0 and
                    # bin 0/batch 10 share a seed). It is kept unchanged so
                    # results stay reproducible against earlier runs.
                    seed = (
                        int(cfg.seed)
                        + (100_000 * int(mc_idx))
                        + (1_000 * int(tau_idx))
                        + (10 * int(bin_idx))
                        + int(batch_idx)
                    )
                    keep, t_ctx = sample_context_in_fixed_age_bin(
                        event_seq=event_seq,
                        time_seq=time_seq,
                        tau_years=float(tau_y),
                        age_bin=(float(a_lo), float(a_hi)),
                        seed=seed,
                    )
                    if not keep.any():
                        continue

                    if h is None:
                        h = model(event_seq, time_seq, sexes,
                                  cont_feats, cate_feats)  # (B,L,D)
                    logits = head(h[rows, t_ctx])
                    cifs = criterion.calculate_cifs(
                        logits, taus=torch.tensor(float(tau_y), device=device)
                    )
                    if cifs.ndim != 2:
                        raise ValueError(
                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
                            f"got shape={tuple(cifs.shape)}"
                        )

                    if cause_ids is None:
                        y = multi_hot_ever_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            n_disease=n_disease,
                        )
                        preds = cifs
                    else:
                        y = multi_hot_selected_causes_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            cause_ids=cause_ids,
                            n_disease=n_disease,
                        )
                        preds = cifs.index_select(dim=1, index=cause_ids)

                    y_true[mc_idx][tau_idx][bin_idx].append(
                        y[keep].detach().to(torch.bool).cpu().numpy()
                    )
                    y_pred[mc_idx][tau_idx][bin_idx].append(
                        preds[keep].detach().to(torch.float32).cpu().numpy()
                    )
    return y_true, y_pred


def _age_bin_metrics_frame(
    y_true: List[List[List[List[np.ndarray]]]],
    y_pred: List[List[List[List[np.ndarray]]]],
    horizons_years: List[float],
    age_bins: List[Tuple[float, float]],
    topk_percents: List[float],
    cause_labels: Sequence[int],
) -> pd.DataFrame:
    """Compute per-(mc, bin, tau, topk, cause) metrics from the collected arrays.

    ``cause_labels[k]`` is the ``cause_id`` reported for column ``k`` of the
    collected label and prediction arrays.
    """
    nan = float("nan")
    rows: List[Dict[str, float | int]] = []
    for mc_idx, (y_mc, p_mc) in enumerate(zip(y_true, y_pred)):
        for h_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                y_chunks = y_mc[h_idx][bin_idx]
                p_chunks = p_mc[h_idx][bin_idx]
                cell = dict(
                    mc_idx=mc_idx,
                    age_bin_id=bin_idx,
                    age_bin_low=float(a_lo),
                    age_bin_high=float(a_hi),
                    horizon_tau=float(tau_y),
                )
                if len(y_chunks) == 0:
                    # No individual was sampled in this bin for this (mc, tau).
                    # Emit placeholder rows so every (bin, cause, K) combination
                    # is present.
                    for cause_id in cause_labels:
                        for k_percent in topk_percents:
                            rows.append(dict(
                                cell,
                                topk_percent=float(k_percent),
                                cause_id=int(cause_id),
                                n_samples=0,
                                n_positives=0,
                                auc=nan,
                                auprc=nan,
                                recall_at_K=nan,
                                precision_at_K=nan,
                                brier_score=nan,
                            ))
                    continue

                yb = np.concatenate(y_chunks, axis=0)
                pb = np.concatenate(p_chunks, axis=0)
                if yb.shape != pb.shape:
                    raise ValueError(
                        f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
                    )
                n_samples = int(yb.shape[0])
                for cause_k, cause_id in enumerate(cause_labels):
                    yk = yb[:, cause_k]
                    pk = pb[:, cause_k]
                    n_pos = int(yk.sum())
                    auc = _binary_roc_auc(yk, pk)
                    auprc = _average_precision(yk, pk)
                    brier = float(np.mean(
                        (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else nan
                    for k_percent in topk_percents:
                        precision_k, recall_k = _precision_recall_at_k_percent(
                            yk, pk, float(k_percent))
                        rows.append(dict(
                            cell,
                            topk_percent=float(k_percent),
                            cause_id=int(cause_id),
                            n_samples=n_samples,
                            n_positives=n_pos,
                            auc=float(auc),
                            auprc=float(auprc),
                            recall_at_K=float(recall_k),
                            precision_at_K=float(precision_k),
                            brier_score=float(brier),
                        ))
    return pd.DataFrame(rows)


def _aggregate_over_bins(group: pd.DataFrame, *, weighted: bool) -> Dict[str, float]:
    """Combine the per-bin rows of one (mc, tau, topk, cause) group into one record.

    Only bins with ``n_samples > 0`` contribute. For each metric, bins where
    that metric is NaN (e.g. AUC in a bin without positives) are skipped. With
    ``weighted=True`` the remaining bins are weighted by ``n_samples``. A
    metric that is NaN in every used bin aggregates to NaN.
    """
    used = group[group["n_samples"] > 0]
    record: Dict[str, float] = dict(
        n_bins_used=int(used["age_bin_id"].nunique()),
        n_samples_total=int(used["n_samples"].sum()),
        n_positives_total=int(used["n_positives"].sum()),
    )
    weights = used["n_samples"].to_numpy(dtype=float)
    for col in _AGE_BIN_METRIC_COLUMNS:
        values = used[col].to_numpy(dtype=float)
        valid = ~np.isnan(values)
        if not valid.any():
            record[col] = float("nan")
        elif weighted:
            # weights[valid] are all > 0 (n_samples > 0 filter above).
            record[col] = float(np.average(values[valid], weights=weights[valid]))
        else:
            record[col] = float(values[valid].mean())
    return record


def _aggregate_age_bin_metrics(df_by_bin: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-bin metrics over age bins (macro and weighted), then over MC runs."""
    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
    records = []
    # Explicit iteration instead of groupby.apply: version-independent, and
    # avoids pandas' deprecation of apply over the grouping columns.
    for keys, group in df_by_bin.groupby(group_keys, sort=True):
        base = dict(zip(group_keys, keys))
        for agg_type, weighted in (("macro", False), ("weighted", True)):
            records.append({
                **base,
                **_aggregate_over_bins(group, weighted=weighted),
                "agg_type": agg_type,
            })
    df_mc_binagg = pd.DataFrame(records)

    # Then average over MC repetitions.
    return (
        df_mc_binagg.groupby(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
        )
        .agg(
            n_mc=("mc_idx", "nunique"),
            n_bins_used_mean=("n_bins_used", "mean"),
            n_samples_total_mean=("n_samples_total", "mean"),
            n_positives_total_mean=("n_positives_total", "mean"),
            auc_mean=("auc", "mean"),
            auc_std=("auc", "std"),
            auprc_mean=("auprc", "mean"),
            auprc_std=("auprc", "std"),
            recall_at_K_mean=("recall_at_K", "mean"),
            recall_at_K_std=("recall_at_K", "std"),
            precision_at_K_mean=("precision_at_K", "mean"),
            precision_at_K_std=("precision_at_K", "std"),
            brier_score_mean=("brier_score", "mean"),
            brier_score_std=("brier_score", "std"),
        )
        .sort_values(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
            ignore_index=True,
        )
    )