from __future__ import annotations import math from dataclasses import dataclass from typing import Dict, List, Optional, Sequence, Tuple import numpy as np import pandas as pd import torch try: from tqdm import tqdm except Exception: # pragma: no cover def tqdm(x, **kwargs): return x from utils import ( multi_hot_ever_within_horizon, multi_hot_selected_causes_within_horizon, sample_context_in_fixed_age_bin, ) from torch_metrics import compute_binary_metrics_torch def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: """Aggregate per-bin age evaluation results. Produces both: - macro: unweighted mean over bins with n_samples>0 - weighted: weighted mean over bins using weights=n_samples Then aggregates across MC repetitions (mean/std). Requires df_by_bin to include: mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id, n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score Returns: DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id) """ if df_by_bin is None or len(df_by_bin) == 0: return pd.DataFrame( columns=[ "agg_type", "horizon_tau", "topk_percent", "cause_id", "n_mc", "n_bins_used_mean", "n_samples_total_mean", "n_positives_total_mean", "auc_mean", "auc_std", "auprc_mean", "auprc_std", "recall_at_K_mean", "recall_at_K_std", "precision_at_K_mean", "precision_at_K_std", "brier_score_mean", "brier_score_std", ] ) def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series: g = group[group["n_samples"] > 0] if len(g) == 0: return pd.Series( dict( n_bins_used=0, n_samples_total=0, n_positives_total=0, auc=float("nan"), auprc=float("nan"), recall_at_K=float("nan"), precision_at_K=float("nan"), brier_score=float("nan"), ) ) n_bins_used = int(g["age_bin_id"].nunique()) n_samples_total = int(g["n_samples"].sum()) n_positives_total = int(g["n_positives"].sum()) if not weighted: return pd.Series( dict( n_bins_used=n_bins_used, n_samples_total=n_samples_total, n_positives_total=n_positives_total, auc=float(g["auc"].mean()), auprc=float(g["auprc"].mean()), recall_at_K=float(g["recall_at_K"].mean()), precision_at_K=float(g["precision_at_K"].mean()), brier_score=float(g["brier_score"].mean()), ) ) w = g["n_samples"].to_numpy(dtype=float) w_sum = float(w.sum()) if w_sum <= 0.0: return pd.Series( dict( n_bins_used=n_bins_used, n_samples_total=n_samples_total, n_positives_total=n_positives_total, auc=float("nan"), auprc=float("nan"), recall_at_K=float("nan"), precision_at_K=float("nan"), brier_score=float("nan"), ) ) def _wavg(col: str) -> float: return float(np.average(g[col].to_numpy(dtype=float), weights=w)) return pd.Series( dict( n_bins_used=n_bins_used, n_samples_total=n_samples_total, n_positives_total=n_positives_total, auc=_wavg("auc"), auprc=_wavg("auprc"), recall_at_K=_wavg("recall_at_K"), precision_at_K=_wavg("precision_at_K"), brier_score=_wavg("brier_score"), ) ) group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"] gb = df_by_bin.groupby(group_keys) try: df_mc_macro = gb.apply( lambda g: _bin_aggregate(g, weighted=False), include_groups=False ).reset_index() except TypeError: # pandas<2.2 (no include_groups) df_mc_macro = gb.apply(lambda g: _bin_aggregate( g, weighted=False)).reset_index() df_mc_macro["agg_type"] = "macro" try: df_mc_weighted = gb.apply( lambda g: _bin_aggregate(g, weighted=True), include_groups=False ).reset_index() except TypeError: # pandas<2.2 (no include_groups) df_mc_weighted = gb.apply( lambda g: _bin_aggregate(g, weighted=True)).reset_index() df_mc_weighted["agg_type"] = "weighted" df_mc_binagg = pd.concat([df_mc_macro, df_mc_weighted], ignore_index=True) df_agg = ( df_mc_binagg.groupby( ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False ) .agg( n_mc=("mc_idx", "nunique"), n_bins_used_mean=("n_bins_used", "mean"), n_samples_total_mean=("n_samples_total", "mean"), n_positives_total_mean=("n_positives_total", "mean"), auc_mean=("auc", "mean"), auc_std=("auc", "std"), auprc_mean=("auprc", "mean"), auprc_std=("auprc", "std"), recall_at_K_mean=("recall_at_K", "mean"), recall_at_K_std=("recall_at_K", "std"), precision_at_K_mean=("precision_at_K", "mean"), precision_at_K_std=("precision_at_K", "std"), brier_score_mean=("brier_score", "mean"), brier_score_std=("brier_score", "std"), ) .sort_values( ["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True, ) ) return df_agg # NOTE: metric computation is torch/GPU-native in `torch_metrics.py`. # NumPy/Pandas are only used for final CSV formatting/aggregation. @dataclass class EvalAgeConfig: horizons_years: Sequence[float] age_bins: Sequence[Tuple[float, float]] topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0) n_mc: int = 5 seed: int = 0 cause_ids: Optional[Sequence[int]] = None @torch.inference_mode() def evaluate_time_dependent_age_bins( model: torch.nn.Module, head: torch.nn.Module, criterion, dataloader: torch.utils.data.DataLoader, n_disease: int, cfg: EvalAgeConfig, device: str | torch.device, mc_offset: int = 0, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Delphi-2M-style age-bin evaluation with strict horizon alignment. Semantics (strict): for each (MC, horizon tau, age bin) we independently: - build the eligible token set within that bin - enforce follow-up coverage: t_ctx + tau <= t_end - randomly sample exactly one token per individual within the bin (de-dup) - recompute context representations and predictions for that (tau, bin) Returns: df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id) df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted} """ device = torch.device(device) model.eval() head.eval() horizons_years = [float(x) for x in cfg.horizons_years] if len(horizons_years) == 0: raise ValueError("cfg.horizons_years must be non-empty") age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins] if len(age_bins) == 0: raise ValueError("cfg.age_bins must be non-empty") for (a, b) in age_bins: if not (b > a): raise ValueError( f"age_bins must be (low, high) with high>low; got {(a, b)}") topk_percents = [float(x) for x in cfg.topk_percents] if len(topk_percents) == 0: raise ValueError("cfg.topk_percents must be non-empty") if any((p <= 0.0 or p > 100.0) for p in topk_percents): raise ValueError( f"All topk_percents must be in (0,100]; got {topk_percents}") if int(cfg.n_mc) <= 0: raise ValueError("cfg.n_mc must be >= 1") if cfg.cause_ids is None: cause_ids = None n_causes_eval = int(n_disease) else: cause_ids = torch.tensor( list(cfg.cause_ids), dtype=torch.long, device=device) n_causes_eval = int(cause_ids.numel()) rows_by_bin: List[Dict[str, float | int]] = [] for mc_idx in range(int(cfg.n_mc)): global_mc_idx = int(mc_offset) + int(mc_idx) # Storage for this MC only: (tau, bin) -> list of GPU tensors. # This keeps computations GPU-first while preventing a factor-n_mc # blow-up in GPU memory. y_true_mc: List[List[List[torch.Tensor]]] = [ [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) ] y_pred_mc: List[List[List[torch.Tensor]]] = [ [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) ] # tqdm over batches; include MC idx in description. for batch_idx, batch in enumerate( tqdm(dataloader, desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch") ): event_seq, time_seq, cont_feats, cate_feats, sexes = batch event_seq = event_seq.to(device) time_seq = time_seq.to(device) cont_feats = cont_feats.to(device) cate_feats = cate_feats.to(device) sexes = sexes.to(device) B = int(event_seq.size(0)) b = torch.arange(B, device=device) # Hoist backbone forward pass: inputs are identical across (tau, age_bin) # within this batch, so this is safe and numerically identical. h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D) for tau_idx, tau_y in enumerate(horizons_years): tau_tensor = torch.tensor(float(tau_y), device=device) for bin_idx, (a_lo, a_hi) in enumerate(age_bins): # Diversify RNG stream across MC/tau/bin/batch to reduce correlation. seed = ( int(cfg.seed) + (100_000 * int(global_mc_idx)) + (1_000 * int(tau_idx)) + (10 * int(bin_idx)) + int(batch_idx) ) keep, t_ctx = sample_context_in_fixed_age_bin( event_seq=event_seq, time_seq=time_seq, tau_years=float(tau_y), age_bin=(float(a_lo), float(a_hi)), seed=seed, ) if not keep.any(): continue # Bin-specific prediction: context indices differ per (tau, bin) # but the backbone features do not. c = h[b, t_ctx] logits = head(c) cifs = criterion.calculate_cifs( logits, taus=tau_tensor ) if cifs.ndim != 2: raise ValueError( "criterion.calculate_cifs must return (B,K) for scalar tau; " f"got shape={tuple(cifs.shape)}" ) if cause_ids is None: y = multi_hot_ever_within_horizon( event_seq=event_seq, time_seq=time_seq, t_ctx=t_ctx, tau_years=float(tau_y), n_disease=n_disease, ) preds = cifs else: y = multi_hot_selected_causes_within_horizon( event_seq=event_seq, time_seq=time_seq, t_ctx=t_ctx, tau_years=float(tau_y), cause_ids=cause_ids, n_disease=n_disease, ) preds = cifs.index_select(dim=1, index=cause_ids) y_true_mc[tau_idx][bin_idx].append( y[keep].detach().to(dtype=torch.bool) ) y_pred_mc[tau_idx][bin_idx].append( preds[keep].detach().to(dtype=torch.float32) ) # Aggregate this MC immediately (frees GPU memory before next MC). for h_idx, tau_y in enumerate(horizons_years): for bin_idx, (a_lo, a_hi) in enumerate(age_bins): if len(y_true_mc[h_idx][bin_idx]) == 0: # No samples in this bin for this (mc, tau) for cause_k in range(n_causes_eval): cause_id = int(cause_k) if cause_ids is None else int( cfg.cause_ids[cause_k]) for k_percent in topk_percents: rows_by_bin.append( dict( mc_idx=global_mc_idx, age_bin_id=bin_idx, age_bin_low=float(a_lo), age_bin_high=float(a_hi), horizon_tau=float(tau_y), topk_percent=float(k_percent), cause_id=cause_id, n_samples=0, n_positives=0, auc=float("nan"), auprc=float("nan"), recall_at_K=float("nan"), precision_at_K=float("nan"), brier_score=float("nan"), ) ) continue yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0) pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0) if tuple(yb_t.shape) != tuple(pb_t.shape): raise ValueError( f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}" ) n_samples = int(yb_t.size(0)) metrics = compute_binary_metrics_torch( y_true=yb_t, y_pred=pb_t, k_percents=topk_percents, tie_mode="exact", chunk_size=128, compute_ici=False, ) # Move just the metric vectors to CPU once per (mc, tau, bin) # for DataFrame construction. auc = metrics.auc_per_cause.detach().cpu().numpy() auprc = metrics.ap_per_cause.detach().cpu().numpy() brier = metrics.brier_per_cause.detach().cpu().numpy() n_pos = metrics.n_pos_per_cause.detach().cpu().numpy() prec_at_k = metrics.precision_at_k.detach().cpu().numpy() # (P,K) rec_at_k = metrics.recall_at_k.detach().cpu().numpy() # (P,K) for cause_k in range(n_causes_eval): cause_id = int(cause_k) if cause_ids is None else int( cfg.cause_ids[cause_k]) for p_idx, k_percent in enumerate(topk_percents): rows_by_bin.append( dict( mc_idx=global_mc_idx, age_bin_id=bin_idx, age_bin_low=float(a_lo), age_bin_high=float(a_hi), horizon_tau=float(tau_y), topk_percent=float(k_percent), cause_id=cause_id, n_samples=n_samples, n_positives=int(n_pos[cause_k]), auc=float(auc[cause_k]), auprc=float(auprc[cause_k]), recall_at_K=float(rec_at_k[p_idx, cause_k]), precision_at_K=float(prec_at_k[p_idx, cause_k]), brier_score=float(brier[cause_k]), ) ) df_by_bin = pd.DataFrame(rows_by_bin) df_agg = aggregate_age_bin_results(df_by_bin) return df_by_bin, df_agg