"""Age-bin-stratified, horizon-aligned evaluation of CIF predictions.

Contains small NumPy implementations of binary ROC AUC, average precision and
precision/recall at the top K% of scores, plus
``evaluate_time_dependent_age_bins``, which samples one context token per
individual inside each age bin, scores the predictions returned by
``criterion.calculate_cifs`` for each horizon, and aggregates the metrics over
age bins and Monte-Carlo repetitions.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

try:
    from tqdm import tqdm
except Exception:  # pragma: no cover
    # tqdm is optional: fall back to a pass-through that ignores its kwargs.
    def tqdm(x, **kwargs):
        return x

from utils import (
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    sample_context_in_fixed_age_bin,
)

# Per-row metric columns produced by ``evaluate_time_dependent_age_bins``.
_METRIC_COLUMNS: Tuple[str, ...] = (
    "auc",
    "auprc",
    "recall_at_K",
    "precision_at_K",
    "brier_score",
)


def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Returns NaN if y_true has no positives or no negatives.
    Uses the Mann–Whitney U statistic with average ranks for ties.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")

    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Rank scores ascending; every run of tied scores receives the mean of the
    # 1-based ranks it spans (vectorised instead of a Python while-loop).
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    run_start = np.empty(n, dtype=bool)
    run_start[0] = True
    # NaN != NaN, so each NaN score forms its own run (same as exact ==).
    run_start[1:] = sorted_scores[1:] != sorted_scores[:-1]
    starts = np.flatnonzero(run_start)  # 0-based first index of each run
    ends = np.append(starts[1:], n)  # exclusive end index of each run
    avg_rank = 0.5 * ((starts + 1) + ends)  # mean of ranks starts+1 .. ends
    run_id = np.cumsum(run_start) - 1
    ranks = np.empty(n, dtype=float)
    ranks[order] = avg_rank[run_id]

    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))


def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under the PR curve, step-wise interpolation).

    AP = sum_n (R_n - R_{n-1}) * P_n over the *distinct* score thresholds, the
    same definition as scikit-learn's ``average_precision_score``.  Tied
    scores are admitted together, so the result does not depend on the input
    order of tied rows.

    Returns NaN if there are no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")
    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")

    order = np.argsort(-y_score, kind="mergesort")
    y = y_true[order]
    s = y_score[order]
    tp = np.cumsum(y, dtype=float)
    fp = np.cumsum(~y, dtype=float)

    # Only the last position of each run of tied scores is a real threshold;
    # evaluating inside a run would make AP depend on arbitrary tie order.
    threshold_end = np.append(s[1:] != s[:-1], True)
    tp = tp[threshold_end]
    fp = fp[threshold_end]

    precision = tp / (tp + fp)  # tp + fp >= 1 at every threshold
    recall = tp / n_pos
    delta_recall = np.diff(recall, prepend=0.0)
    ap = float(np.sum(delta_recall * precision))
    # handle potential tiny numerical overshoots
    return float(max(0.0, min(1.0, ap)))


def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    The top ``k = ceil(k_percent% of n)`` rows by score are selected
    (clamped to ``[1, n]`` when k_percent > 0).  Ties at the cut-off are
    broken by input order (stable sort).

    Returns (precision, recall).
    Returns NaN for recall if no positives.
    Returns NaN for both if the input is empty or k_percent <= 0 (or NaN).
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")

    n_pos = int(y_true.sum())

    # Multiply before dividing and allow a tiny relative tolerance: e.g.
    # (7 / 100) * 100 == 7.000000000000001, which a bare ceil() turns into 8.
    raw_k = float(k_percent) * n / 100.0
    if not raw_k > 0.0:
        return float("nan"), float("nan")
    k = int(math.ceil(raw_k - 1e-9 * max(1.0, raw_k)))
    k = min(n, max(1, k))  # never select more rows than exist

    order = np.argsort(-y_score, kind="mergesort")
    top = order[:k]
    tp_top = int(y_true[top].sum())

    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)


@dataclass
class EvalAgeConfig:
    # Prediction horizons tau, passed to the helpers as ``tau_years``.
    horizons_years: Sequence[float]
    # (low, high) age bins with high > low; units follow ``time_seq``
    # (presumably years — TODO confirm against utils).
    age_bins: Sequence[Tuple[float, float]]
    # Top-K% cut-offs for precision/recall@K; each must lie in (0, 100].
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    # Number of Monte-Carlo repetitions of the per-bin context sampling.
    n_mc: int = 5
    # Base seed mixed into every per-(MC, tau, bin, batch) sampling seed.
    seed: int = 0
    # Optional subset of cause columns to evaluate; None evaluates all n_disease.
    cause_ids: Optional[Sequence[int]] = None


def _validate_eval_config(
    cfg: EvalAgeConfig,
) -> Tuple[List[float], List[Tuple[float, float]], List[float], int]:
    """Normalise and validate ``cfg``; return (horizons, age_bins, topk_percents, n_mc)."""
    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")

    age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
    if len(age_bins) == 0:
        raise ValueError("cfg.age_bins must be non-empty")
    for (a, b) in age_bins:
        if not (b > a):  # also rejects NaN bounds
            raise ValueError(
                f"age_bins must be (low, high) with high>low; got {(a, b)}")

    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    # Written as a positive range check so NaN entries are rejected too.
    if not all(0.0 < p <= 100.0 for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0,100]; got {topk_percents}")

    n_mc = int(cfg.n_mc)
    if n_mc <= 0:
        raise ValueError("cfg.n_mc must be >= 1")

    return horizons_years, age_bins, topk_percents, n_mc


def _mix_seed(*parts: int) -> int:
    """Hash integer parts into one 32-bit seed.

    Unlike a weighted sum (where e.g. bin 0 / batch 10 and bin 1 / batch 0
    collide), distinct tuples almost surely map to distinct seeds.  The
    result fits in [0, 2**32) so it is valid for common NumPy/torch seeders.
    """
    entropy = [int(p) % (1 << 64) for p in parts]  # SeedSequence needs >= 0
    state = np.random.SeedSequence(entropy).generate_state(1, dtype=np.uint32)
    return int(state[0])


def _metric_rows_for_bin(
    y_chunks: List[np.ndarray],
    p_chunks: List[np.ndarray],
    *,
    base: Dict[str, float | int],
    cause_id_list: Sequence[int],
    topk_percents: Sequence[float],
) -> List[Dict[str, float | int]]:
    """Build one metrics row per (cause, K%) for a single (mc, tau, age-bin) cell.

    ``base`` holds the cell keys (mc_idx, age_bin_id, age_bin_low,
    age_bin_high, horizon_tau) and is copied into every row.  An empty cell
    (no chunks) yields rows with n_samples=0 and NaN metrics.
    """
    nan = float("nan")
    rows: List[Dict[str, float | int]] = []

    if not y_chunks:
        # No samples in this bin for this (mc, tau)
        for cause_id in cause_id_list:
            for k_percent in topk_percents:
                rows.append(
                    dict(
                        base,
                        topk_percent=float(k_percent),
                        cause_id=cause_id,
                        n_samples=0,
                        n_positives=0,
                        auc=nan,
                        auprc=nan,
                        recall_at_K=nan,
                        precision_at_K=nan,
                        brier_score=nan,
                    )
                )
        return rows

    yb = np.concatenate(y_chunks, axis=0)
    pb = np.concatenate(p_chunks, axis=0)
    if yb.shape != pb.shape:
        raise ValueError(
            f"Shape mismatch mc={base['mc_idx']} tau={base['horizon_tau']} "
            f"bin={base['age_bin_id']}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
        )

    n_samples = int(yb.shape[0])
    for cause_k, cause_id in enumerate(cause_id_list):
        yk = yb[:, cause_k]
        pk = pb[:, cause_k]
        n_pos = int(yk.sum())
        auc = _binary_roc_auc(yk, pk)
        auprc = _average_precision(yk, pk)
        brier = float(np.mean(
            (yk.astype(float) - pk.astype(float)) ** 2)) if n_samples > 0 else nan
        for k_percent in topk_percents:
            precision_k, recall_k = _precision_recall_at_k_percent(
                yk, pk, float(k_percent))
            rows.append(
                dict(
                    base,
                    topk_percent=float(k_percent),
                    cause_id=cause_id,
                    n_samples=n_samples,
                    n_positives=n_pos,
                    auc=float(auc),
                    auprc=float(auprc),
                    recall_at_K=float(recall_k),
                    precision_at_K=float(precision_k),
                    brier_score=float(brier),
                )
            )
    return rows


def _aggregate_over_bins(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
    """Combine the per-bin rows of one (mc, tau, K%, cause) group.

    Bins with n_samples == 0 are dropped.  Each metric is averaged over the
    bins where it is defined (non-NaN) — a plain mean when ``weighted`` is
    False, an n_samples-weighted mean otherwise — so a single bin without
    positives does not turn the weighted AUC/AP into NaN.
    """
    nan = float("nan")
    g = group[group["n_samples"] > 0]
    if len(g) == 0:
        return pd.Series(
            dict(
                n_bins_used=0,
                n_samples_total=0,
                n_positives_total=0,
                **{col: nan for col in _METRIC_COLUMNS},
            )
        )

    weights = g["n_samples"].to_numpy(dtype=float)

    def _combine(col: str) -> float:
        values = g[col].to_numpy(dtype=float)
        valid = ~np.isnan(values)
        if not valid.any():
            return nan
        if not weighted:
            return float(values[valid].mean())
        w = weights[valid]
        if float(w.sum()) <= 0.0:
            return nan
        return float(np.average(values[valid], weights=w))

    return pd.Series(
        dict(
            n_bins_used=int(g["age_bin_id"].nunique()),
            n_samples_total=int(g["n_samples"].sum()),
            n_positives_total=int(g["n_positives"].sum()),
            **{col: _combine(col) for col in _METRIC_COLUMNS},
        )
    )


@torch.no_grad()
def evaluate_time_dependent_age_bins(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalAgeConfig,
    device: str | torch.device,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Delphi-2M-style age-bin evaluation with strict horizon alignment.

    Semantics (strict): for each (MC, horizon tau, age bin) we independently
    (via ``utils.sample_context_in_fixed_age_bin``):
      - build the eligible token set within that bin
      - enforce follow-up coverage: t_ctx + tau <= t_end
      - randomly sample exactly one token per individual within the bin (de-dup)
    and then take the context representation at the sampled token and compute
    predictions for that (tau, bin).  The backbone forward ``model(...)`` only
    depends on the batch, so it is computed once per batch (lazily, only if
    some cell keeps a sample) and reused across all (tau, bin) cells.

    Each (MC, tau, bin, batch) sampling call gets its own seed, derived by
    hashing ``(cfg.seed, mc_idx, tau_idx, bin_idx, batch_idx)``.

    Args:
        model: backbone called as ``model(event_seq, time_seq, sexes,
            cont_feats, cate_feats)``; indexed as ``h[b, t_ctx]`` so it is
            assumed to return (B, L, D).
        head: maps context representations to logits.
        criterion: must provide ``calculate_cifs(logits, taus=...)`` returning
            a (B, K) tensor for a scalar tau.
        dataloader: yields ``(event_seq, time_seq, cont_feats, cate_feats, sexes)``.
        n_disease: number of disease columns, forwarded to the label helpers.
        cfg: evaluation configuration.
        device: device the batches are moved to.

    Returns:
        df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id)
        df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted}

    Raises:
        ValueError: invalid ``cfg``, no causes to evaluate, a non-2D result
            from ``calculate_cifs``, or label/prediction shape mismatch.
    """
    device = torch.device(device)
    model.eval()
    head.eval()

    horizons_years, age_bins, topk_percents, n_mc = _validate_eval_config(cfg)

    if cfg.cause_ids is None:
        cause_ids = None
        cause_id_list = list(range(int(n_disease)))
    else:
        cause_id_list = [int(c) for c in cfg.cause_ids]
        cause_ids = torch.tensor(
            cause_id_list, dtype=torch.long, device=device)
    if not cause_id_list:
        # Without causes no rows are produced and the aggregation below fails
        # with an opaque pandas KeyError; fail early with a clear message.
        raise ValueError(
            "no causes to evaluate: cfg.cause_ids is empty or n_disease <= 0")

    # Storage: [mc][tau][bin] -> list of per-batch arrays
    n_h, n_bins = len(horizons_years), len(age_bins)
    y_true_store: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in range(n_bins)] for _ in range(n_h)] for _ in range(n_mc)
    ]
    y_pred_store: List[List[List[List[np.ndarray]]]] = [
        [[[] for _ in range(n_bins)] for _ in range(n_h)] for _ in range(n_mc)
    ]

    for mc_idx in range(n_mc):
        # tqdm over batches; include MC idx in description.
        batches = tqdm(
            dataloader, desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
        for batch_idx, batch in enumerate(batches):
            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
            event_seq = event_seq.to(device)
            time_seq = time_seq.to(device)
            cont_feats = cont_feats.to(device)
            cate_feats = cate_feats.to(device)
            sexes = sexes.to(device)

            B = int(event_seq.size(0))
            b = torch.arange(B, device=device)
            h = None  # backbone output for this batch, computed on first use

            for tau_idx, tau_y in enumerate(horizons_years):
                tau_t = torch.tensor(tau_y, device=device)
                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                    keep, t_ctx = sample_context_in_fixed_age_bin(
                        event_seq=event_seq,
                        time_seq=time_seq,
                        tau_years=tau_y,
                        age_bin=(a_lo, a_hi),
                        seed=_mix_seed(
                            cfg.seed, mc_idx, tau_idx, bin_idx, batch_idx),
                    )
                    if not keep.any():
                        continue

                    if h is None:
                        # Inputs are identical for every (tau, bin): run once.
                        h = model(event_seq, time_seq, sexes,
                                  cont_feats, cate_feats)  # (B,L,D)
                    logits = head(h[b, t_ctx])
                    cifs = criterion.calculate_cifs(logits, taus=tau_t)
                    if cifs.ndim != 2:
                        raise ValueError(
                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
                            f"got shape={tuple(cifs.shape)}"
                        )

                    if cause_ids is None:
                        y = multi_hot_ever_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=tau_y,
                            n_disease=n_disease,
                        )
                        preds = cifs
                    else:
                        y = multi_hot_selected_causes_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=tau_y,
                            cause_ids=cause_ids,
                            n_disease=n_disease,
                        )
                        preds = cifs.index_select(dim=1, index=cause_ids)

                    y_true_store[mc_idx][tau_idx][bin_idx].append(
                        y[keep].detach().to(torch.bool).cpu().numpy()
                    )
                    y_pred_store[mc_idx][tau_idx][bin_idx].append(
                        preds[keep].detach().to(torch.float32).cpu().numpy()
                    )

    rows_by_bin: List[Dict[str, float | int]] = []
    for mc_idx in range(n_mc):
        for tau_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                rows_by_bin.extend(
                    _metric_rows_for_bin(
                        y_true_store[mc_idx][tau_idx][bin_idx],
                        y_pred_store[mc_idx][tau_idx][bin_idx],
                        base=dict(
                            mc_idx=mc_idx,
                            age_bin_id=bin_idx,
                            age_bin_low=a_lo,
                            age_bin_high=a_hi,
                            horizon_tau=tau_y,
                        ),
                        cause_id_list=cause_id_list,
                        topk_percents=topk_percents,
                    )
                )
    df_by_bin = pd.DataFrame(rows_by_bin)

    # Aggregate over age bins within each MC repetition.  Only the value
    # columns are handed to apply(), so pandas does not pass (or warn about)
    # the grouping columns.
    group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"]
    value_cols = ["age_bin_id", "n_samples", "n_positives", *_METRIC_COLUMNS]
    grouped = df_by_bin.groupby(group_keys)[value_cols]

    df_mc_macro = grouped.apply(
        lambda g: _aggregate_over_bins(g, weighted=False)).reset_index()
    df_mc_macro["agg_type"] = "macro"

    df_mc_weighted = grouped.apply(
        lambda g: _aggregate_over_bins(g, weighted=True)).reset_index()
    df_mc_weighted["agg_type"] = "weighted"

    df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True)

    # Then average over MC repetitions (std uses ddof=1, so NaN when n_mc == 1).
    df_agg = (
        df_mc_binagg.groupby(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False
        )
        .agg(
            n_mc=("mc_idx", "nunique"),
            n_bins_used_mean=("n_bins_used", "mean"),
            n_samples_total_mean=("n_samples_total", "mean"),
            n_positives_total_mean=("n_positives_total", "mean"),
            auc_mean=("auc", "mean"),
            auc_std=("auc", "std"),
            auprc_mean=("auprc", "mean"),
            auprc_std=("auprc", "std"),
            recall_at_K_mean=("recall_at_K", "mean"),
            recall_at_K_std=("recall_at_K", "std"),
            precision_at_K_mean=("precision_at_K", "mean"),
            precision_at_K_std=("precision_at_K", "std"),
            brier_score_mean=("brier_score", "mean"),
            brier_score_std=("brier_score", "std"),
        )
        .sort_values(
            ["agg_type", "horizon_tau", "topk_percent", "cause_id"],
            ignore_index=True,
        )
    )

    return df_by_bin, df_agg