from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from utils import (
    multi_hot_ever_within_horizon,
    multi_hot_selected_causes_within_horizon,
    sample_context_in_fixed_age_bin,
)


def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-bin age evaluation results.

    Two aggregation steps are applied:

    1. Across age bins, separately for every
       (mc_idx, horizon_tau, topk_percent, cause_id) group, using only bins
       with ``n_samples > 0``:

       - ``macro``: unweighted mean over bins.
       - ``weighted``: mean over bins weighted by ``n_samples``.

       In both variants a bin whose metric is NaN (e.g. AUC of a bin without
       positives) is skipped for that metric; for ``weighted`` the weights are
       renormalised over the remaining bins. A metric is NaN only when no bin
       provides a value for it.

    2. Across MC repetitions (mean and sample standard deviation).

    Args:
        df_by_bin: Per-bin results. Must contain the columns ``mc_idx``,
            ``horizon_tau``, ``topk_percent``, ``cause_id``, ``age_bin_id``,
            ``n_samples``, ``n_positives``, ``auc``, ``auprc``,
            ``recall_at_K``, ``precision_at_K`` and ``brier_score``.

    Returns:
        DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id)
        with ``n_mc``, ``*_mean`` count summaries and ``<metric>_mean`` /
        ``<metric>_std`` columns. An empty frame with exactly these columns
        is returned when ``df_by_bin`` is ``None`` or has no rows.
    """
    metric_cols = ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]
    key_cols = ["horizon_tau", "topk_percent", "cause_id"]
    count_cols = ["n_bins_used", "n_samples_total", "n_positives_total"]

    if df_by_bin is None or len(df_by_bin) == 0:
        return pd.DataFrame(
            columns=[
                "agg_type",
                *key_cols,
                "n_mc",
                *[f"{col}_mean" for col in count_cols],
                *[
                    f"{metric}_{stat}"
                    for metric in metric_cols
                    for stat in ("mean", "std")
                ],
            ]
        )

    def _nan_weighted_mean(values: np.ndarray, weights: np.ndarray) -> float:
        # Mirror pandas' skipna mean: ignore bins where the metric is
        # undefined instead of letting one NaN poison the whole average.
        defined = ~np.isnan(values)
        if not defined.any():
            return float("nan")
        kept_weights = weights[defined]
        if float(kept_weights.sum()) <= 0.0:
            return float("nan")
        return float(np.average(values[defined], weights=kept_weights))

    def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series:
        # Collapse the age bins of one (mc, tau, topk, cause) group.
        used = group[group["n_samples"] > 0]
        if len(used) == 0:
            summary = dict(n_bins_used=0, n_samples_total=0, n_positives_total=0)
            summary.update({metric: float("nan") for metric in metric_cols})
            return pd.Series(summary)

        summary = dict(
            n_bins_used=int(used["age_bin_id"].nunique()),
            n_samples_total=int(used["n_samples"].sum()),
            n_positives_total=int(used["n_positives"].sum()),
        )
        weights = used["n_samples"].to_numpy(dtype=float)
        for metric in metric_cols:
            if weighted:
                summary[metric] = _nan_weighted_mean(
                    used[metric].to_numpy(dtype=float), weights
                )
            else:
                # pandas' mean skips NaN by default.
                summary[metric] = float(used[metric].mean())
        return pd.Series(summary)

    # Only hand the columns the callback reads to ``apply``; this keeps the
    # result independent of whether pandas passes grouping columns through.
    value_cols = ["age_bin_id", "n_samples", "n_positives", *metric_cols]
    grouped = df_by_bin.groupby(["mc_idx", *key_cols])[value_cols]

    per_mc_frames = []
    for agg_type, weighted in (("macro", False), ("weighted", True)):
        per_mc = grouped.apply(
            lambda grp, weighted=weighted: _bin_aggregate(grp, weighted=weighted)
        ).reset_index()
        per_mc["agg_type"] = agg_type
        per_mc_frames.append(per_mc)
    df_mc_binagg = pd.concat(per_mc_frames, ignore_index=True)

    named_aggs = {"n_mc": ("mc_idx", "nunique")}
    named_aggs.update({f"{col}_mean": (col, "mean") for col in count_cols})
    for metric in metric_cols:
        named_aggs[f"{metric}_mean"] = (metric, "mean")
        named_aggs[f"{metric}_std"] = (metric, "std")

    out_keys = ["agg_type", *key_cols]
    df_agg = (
        df_mc_binagg.groupby(out_keys, as_index=False)
        .agg(**named_aggs)
        .sort_values(out_keys, ignore_index=True)
    )
    return df_agg
def _binary_roc_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute ROC AUC for binary labels with tie-aware ranking.

    Uses the Mann-Whitney U statistic with average ranks for tied scores.

    Args:
        y_true: 1-D binary labels (anything castable to bool).
        y_score: 1-D scores, same length as ``y_true``; higher means more
            likely positive.

    Returns:
        AUC in [0, 1], or NaN if the input is empty or contains only one
        class.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")

    n_pos = int(y_true.sum())
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Sort ascending and find runs of equal scores; every member of a run
    # gets the mean of the 1-based ranks the run occupies.
    order = np.argsort(y_score, kind="mergesort")
    sorted_scores = y_score[order]
    starts_run = np.empty(n, dtype=bool)
    starts_run[0] = True
    starts_run[1:] = sorted_scores[1:] != sorted_scores[:-1]

    run_start = np.flatnonzero(starts_run)  # 0-based, inclusive
    run_end = np.append(run_start[1:], n)  # 0-based, exclusive
    run_avg_rank = 0.5 * ((run_start + 1) + run_end)  # ranks start+1 .. end

    ranks = np.empty(n, dtype=float)
    ranks[order] = run_avg_rank[np.cumsum(starts_run) - 1]

    sum_ranks_pos = float(ranks[y_true].sum())
    u = sum_ranks_pos - (n_pos * (n_pos + 1) / 2.0)
    return float(u / (n_pos * n_neg))


def _average_precision(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Average precision (area under the PR curve, step-wise interpolation).

    Samples are ranked by descending score with a stable sort, so tied scores
    keep their input order.

    Args:
        y_true: 1-D binary labels (anything castable to bool).
        y_score: 1-D scores, same length as ``y_true``.

    Returns:
        Average precision in [0, 1], or NaN if the input is empty or has no
        positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan")

    n_pos = int(y_true.sum())
    if n_pos == 0:
        return float("nan")

    order = np.argsort(-y_score, kind="mergesort")
    y = y_true[order]

    tp = np.cumsum(y).astype(float)
    fp = np.cumsum(~y).astype(float)
    precision = tp / np.maximum(tp + fp, 1.0)

    # AP = mean over positives of the precision at that positive's rank
    # (equivalent to sum of delta-recall * precision).
    ap = float(np.sum(precision[y]) / n_pos)
    # Guard against tiny numerical overshoots.
    return float(max(0.0, min(1.0, ap)))


def _precision_recall_at_k_percent(
    y_true: np.ndarray,
    y_score: np.ndarray,
    k_percent: float,
) -> Tuple[float, float]:
    """Precision@K% and Recall@K% for binary labels.

    K is ``ceil(k_percent * n / 100)`` capped at ``n``; the K highest-scoring
    samples (stable order for ties) are treated as selected.

    Args:
        y_true: 1-D binary labels (anything castable to bool).
        y_score: 1-D scores, same length as ``y_true``.
        k_percent: Percentage of samples to select.

    Returns:
        ``(precision, recall)``. Both are NaN for empty input or when K is 0;
        recall is NaN when there are no positives.
    """
    y_true = np.asarray(y_true).astype(bool)
    y_score = np.asarray(y_score).astype(float)

    n = y_true.size
    if n == 0:
        return float("nan"), float("nan")

    n_pos = int(y_true.sum())

    # Multiply before dividing: (k_percent / 100) * n can overshoot an exact
    # integer (0.07 * 100 == 7.000000000000001) and make ceil() pick K + 1.
    k = int(math.ceil(float(k_percent) * n / 100.0))
    if k <= 0:
        return float("nan"), float("nan")
    # Never divide by more items than can actually be selected.
    k = min(k, n)

    order = np.argsort(-y_score, kind="mergesort")
    top = order[:k]

    tp_top = int(y_true[top].sum())

    precision = tp_top / k
    recall = float("nan") if n_pos == 0 else (tp_top / n_pos)
    return float(precision), float(recall)


@dataclass
class EvalAgeConfig:
    # Prediction horizons tau; must be non-empty. Presumably in years, as the
    # name suggests (passed on as ``tau_years`` / ``taus``).
    horizons_years: Sequence[float]
    # (low, high) age bins; must be non-empty and each must satisfy high > low.
    age_bins: Sequence[Tuple[float, float]]
    # Percentages for precision/recall@K%; each must lie in (0, 100].
    topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0)
    # Number of Monte-Carlo repetitions; must be >= 1.
    n_mc: int = 5
    # Base seed from which per-(MC, tau, bin, batch) sampling seeds are derived.
    seed: int = 0
    # Causes to evaluate, used as column indices into the CIF matrix;
    # None evaluates all ``n_disease`` causes.
    cause_ids: Optional[Sequence[int]] = None


@torch.no_grad()
def evaluate_time_dependent_age_bins(
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion,
    dataloader: torch.utils.data.DataLoader,
    n_disease: int,
    cfg: EvalAgeConfig,
    device: str | torch.device,
    mc_offset: int = 0,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Delphi-2M-style age-bin evaluation with strict horizon alignment.

    For each (MC repetition, horizon tau, age bin) a context position is
    sampled independently per individual via
    ``sample_context_in_fixed_age_bin``; per its original description this
    enforces follow-up coverage (``t_ctx + tau <= t_end``) and draws exactly
    one token per individual inside the bin. Context vectors, head logits and
    CIFs are then computed for that (tau, bin).

    The encoder output ``model(...)`` depends only on the batch, so it is
    computed at most once per batch and reused for all (tau, bin) pairs. This
    assumes ``model`` is deterministic in eval mode.

    Args:
        model: Sequence encoder, called as
            ``model(event_seq, time_seq, sexes, cont_feats, cate_feats)``.
            Its output is indexed as ``h[batch_index, t_ctx]``
            (presumably shaped (B, L, D)).
        head: Maps the gathered context vectors to logits.
        criterion: Object exposing ``calculate_cifs(logits, taus=...)``, which
            must return a 2-D (B, K) tensor for a scalar tau.
        dataloader: Yields 5-tuples
            ``(event_seq, time_seq, cont_feats, cate_feats, sexes)`` of
            tensors. It is iterated once per MC repetition.
        n_disease: Number of causes; used as the number of evaluated causes
            when ``cfg.cause_ids`` is None.
        cfg: Evaluation settings, see ``EvalAgeConfig``.
        device: Device the batches are moved to.
        mc_offset: Added to the local MC index for seeding and for the
            reported ``mc_idx``.

    Returns:
        ``(df_by_bin, df_agg)``. ``df_by_bin`` has one row per
        (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id); bins with
        no samples get ``n_samples == 0`` and NaN metrics. ``df_agg`` is
        ``aggregate_age_bin_results(df_by_bin)``.

    Raises:
        ValueError: If the configuration is invalid, if ``calculate_cifs``
            does not return a 2-D tensor, or if collected label and prediction
            arrays disagree in shape.
    """
    device = torch.device(device)

    model.eval()
    head.eval()

    horizons_years = [float(x) for x in cfg.horizons_years]
    if len(horizons_years) == 0:
        raise ValueError("cfg.horizons_years must be non-empty")

    age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins]
    if len(age_bins) == 0:
        raise ValueError("cfg.age_bins must be non-empty")
    for (a, b) in age_bins:
        if not (b > a):
            raise ValueError(
                f"age_bins must be (low, high) with high>low; got {(a, b)}")

    topk_percents = [float(x) for x in cfg.topk_percents]
    if len(topk_percents) == 0:
        raise ValueError("cfg.topk_percents must be non-empty")
    if any((p <= 0.0 or p > 100.0) for p in topk_percents):
        raise ValueError(
            f"All topk_percents must be in (0,100]; got {topk_percents}")

    n_mc = int(cfg.n_mc)
    if n_mc <= 0:
        raise ValueError("cfg.n_mc must be >= 1")

    # cause_labels[k] is the cause id reported for column k of the collected
    # label / prediction matrices.
    if cfg.cause_ids is None:
        cause_ids = None
        cause_labels = list(range(int(n_disease)))
    else:
        cause_labels = [int(c) for c in cfg.cause_ids]
        cause_ids = torch.tensor(
            cause_labels, dtype=torch.long, device=device)

    nan = float("nan")
    rows_by_bin: List[Dict[str, float | int]] = []

    for mc_idx in range(n_mc):
        global_mc_idx = int(mc_offset) + int(mc_idx)

        # Per-repetition storage, indexed [tau_idx][bin_idx] -> list of
        # per-batch arrays. Kept per repetition so memory is released before
        # the next pass over the dataloader.
        y_true: List[List[List[np.ndarray]]] = [
            [[] for _ in age_bins] for _ in horizons_years
        ]
        y_pred: List[List[List[np.ndarray]]] = [
            [[] for _ in age_bins] for _ in horizons_years
        ]

        # tqdm over batches; include MC idx in description.
        for batch_idx, batch in enumerate(
            tqdm(dataloader,
                 desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch")
        ):
            event_seq, time_seq, cont_feats, cate_feats, sexes = batch
            event_seq = event_seq.to(device)
            time_seq = time_seq.to(device)
            cont_feats = cont_feats.to(device)
            cate_feats = cate_feats.to(device)
            sexes = sexes.to(device)

            batch_size = int(event_seq.size(0))
            row_index = torch.arange(batch_size, device=device)

            # Encoder output for this batch; filled in on first use so that
            # batches without any eligible sample never run the model.
            hidden = None

            for tau_idx, tau_y in enumerate(horizons_years):
                for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                    # Diversify RNG stream across MC/tau/bin/batch to reduce
                    # correlation.
                    # NOTE(review): this additive scheme is collision-free
                    # only while batch_idx < 10, bin_idx < 100 and
                    # tau_idx < 100; beyond that, different (tau, bin, batch)
                    # combinations share a seed. Left unchanged so earlier
                    # runs stay reproducible.
                    seed = (
                        int(cfg.seed)
                        + (100_000 * int(global_mc_idx))
                        + (1_000 * int(tau_idx))
                        + (10 * int(bin_idx))
                        + int(batch_idx)
                    )

                    # keep is used as a per-individual mask and t_ctx as a
                    # per-individual position index (see usage below).
                    keep, t_ctx = sample_context_in_fixed_age_bin(
                        event_seq=event_seq,
                        time_seq=time_seq,
                        tau_years=float(tau_y),
                        age_bin=(float(a_lo), float(a_hi)),
                        seed=seed,
                    )
                    if not keep.any():
                        continue

                    if hidden is None:
                        hidden = model(
                            event_seq, time_seq, sexes, cont_feats, cate_feats
                        )  # presumably (B, L, D)

                    # Context vectors, logits and CIFs are specific to this
                    # (tau, bin) because t_ctx and tau differ.
                    context = hidden[row_index, t_ctx]
                    logits = head(context)

                    cifs = criterion.calculate_cifs(
                        logits, taus=torch.tensor(float(tau_y), device=device)
                    )
                    if cifs.ndim != 2:
                        raise ValueError(
                            "criterion.calculate_cifs must return (B,K) for scalar tau; "
                            f"got shape={tuple(cifs.shape)}"
                        )

                    if cause_ids is None:
                        y = multi_hot_ever_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            n_disease=n_disease,
                        )
                        preds = cifs
                    else:
                        y = multi_hot_selected_causes_within_horizon(
                            event_seq=event_seq,
                            time_seq=time_seq,
                            t_ctx=t_ctx,
                            tau_years=float(tau_y),
                            cause_ids=cause_ids,
                            n_disease=n_disease,
                        )
                        preds = cifs.index_select(dim=1, index=cause_ids)

                    y_true[tau_idx][bin_idx].append(
                        y[keep].detach().to(torch.bool).cpu().numpy()
                    )
                    y_pred[tau_idx][bin_idx].append(
                        preds[keep].detach().to(torch.float32).cpu().numpy()
                    )

        # Turn this repetition's collected arrays into result rows.
        for h_idx, tau_y in enumerate(horizons_years):
            for bin_idx, (a_lo, a_hi) in enumerate(age_bins):
                if y_true[h_idx][bin_idx]:
                    yb = np.concatenate(y_true[h_idx][bin_idx], axis=0)
                    pb = np.concatenate(y_pred[h_idx][bin_idx], axis=0)
                    if yb.shape != pb.shape:
                        raise ValueError(
                            f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb.shape)} vs p{tuple(pb.shape)}"
                        )
                    n_samples = int(yb.shape[0])
                else:
                    # No samples in this bin for this (mc, tau): emit
                    # placeholder rows so every key combination is present.
                    yb = pb = None
                    n_samples = 0

                for cause_k, cause_id in enumerate(cause_labels):
                    if yb is None:
                        n_pos = 0
                        auc = auprc = brier = nan
                        at_k = [(nan, nan) for _ in topk_percents]
                    else:
                        yk = yb[:, cause_k]
                        pk = pb[:, cause_k]
                        n_pos = int(yk.sum())
                        auc = _binary_roc_auc(yk, pk)
                        auprc = _average_precision(yk, pk)
                        brier = float(np.mean(
                            (yk.astype(float) - pk.astype(float)) ** 2
                        )) if n_samples > 0 else nan
                        at_k = [
                            _precision_recall_at_k_percent(
                                yk, pk, float(k_percent))
                            for k_percent in topk_percents
                        ]

                    # auc / auprc / brier do not depend on K and are repeated
                    # on every topk row of the same cause.
                    for k_percent, (precision_k, recall_k) in zip(
                        topk_percents, at_k
                    ):
                        rows_by_bin.append(
                            dict(
                                mc_idx=global_mc_idx,
                                age_bin_id=bin_idx,
                                age_bin_low=float(a_lo),
                                age_bin_high=float(a_hi),
                                horizon_tau=float(tau_y),
                                topk_percent=float(k_percent),
                                cause_id=cause_id,
                                n_samples=n_samples,
                                n_positives=n_pos,
                                auc=float(auc),
                                auprc=float(auprc),
                                recall_at_K=float(recall_k),
                                precision_at_K=float(precision_k),
                                brier_score=float(brier),
                            )
                        )

    df_by_bin = pd.DataFrame(rows_by_bin)
    df_agg = aggregate_age_bin_results(df_by_bin)
    return df_by_bin, df_agg