from __future__ import annotations import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np import pandas as pd import torch try: from tqdm import tqdm except Exception: # pragma: no cover def tqdm(x, **kwargs): return x from utils import ( multi_hot_ever_within_horizon, multi_hot_selected_causes_within_horizon, sample_context_in_fixed_age_bin, ) from torch_metrics import compute_binary_metrics_torch def _nanmean_np(x: np.ndarray, axis: int = 0) -> np.ndarray: with np.errstate(invalid="ignore"): return np.nanmean(x, axis=axis) def _nanstd_np_ddof1(x: np.ndarray, axis: int = 0) -> np.ndarray: """NaN-aware sample std (ddof=1), matching pandas std() semantics.""" x = np.asarray(x, dtype=float) mask = np.isfinite(x) cnt = mask.sum(axis=axis) # mean over finite entries x0 = np.where(mask, x, 0.0) mean = x0.sum(axis=axis) / np.maximum(cnt, 1) # sum of squared deviations over finite entries dev2 = np.where(mask, (x - np.expand_dims(mean, axis=axis)) ** 2, 0.0) ss = dev2.sum(axis=axis) denom = cnt - 1 out = np.sqrt(ss / np.maximum(denom, 1)) out = np.where(denom > 0, out, np.nan) return out def _weighted_mean_np(x: np.ndarray, w: np.ndarray, axis: int = 0) -> np.ndarray: """NaN-aware weighted mean. Only bins with finite x contribute to both numerator and denominator. If denom==0 -> NaN. """ x = np.asarray(x, dtype=float) w = np.asarray(w, dtype=float) if axis != 0: raise ValueError("_weighted_mean_np currently supports axis=0 only") # Broadcast weights along trailing dims of x. while w.ndim < x.ndim: w = w[..., None] w = np.broadcast_to(w, x.shape) mask = np.isfinite(x) num = np.where(mask, x * w, 0.0).sum(axis=0) denom = np.where(mask, w, 0.0).sum(axis=0) return np.where(denom > 0.0, num / denom, np.nan) def _blocks_to_df_by_bin( blocks: List[Dict[str, Any]], *, topk_percents: np.ndarray, ) -> pd.DataFrame: """Convert per-block column vectors into the long-format per-bin DataFrame. This does a single vectorized reshape per block (cause-major ordering), and concatenates columns once at the end. """ if len(blocks) == 0: return pd.DataFrame( columns=[ "mc_idx", "age_bin_id", "age_bin_low", "age_bin_high", "horizon_tau", "topk_percent", "cause_id", "n_samples", "n_positives", "auc", "auprc", "recall_at_K", "precision_at_K", "brier_score", ] ) P = int(topk_percents.size) cols: Dict[str, List[np.ndarray]] = { "mc_idx": [], "age_bin_id": [], "age_bin_low": [], "age_bin_high": [], "horizon_tau": [], "topk_percent": [], "cause_id": [], "n_samples": [], "n_positives": [], "auc": [], "auprc": [], "recall_at_K": [], "precision_at_K": [], "brier_score": [], } for blk in blocks: cause_id = np.asarray(blk["cause_id"], dtype=int) K = int(cause_id.size) n_rows = K * P cols["mc_idx"].append(np.full(n_rows, int(blk["mc_idx"]), dtype=int)) cols["age_bin_id"].append( np.full(n_rows, int(blk["age_bin_id"]), dtype=int)) cols["age_bin_low"].append( np.full(n_rows, float(blk["age_bin_low"]), dtype=float)) cols["age_bin_high"].append( np.full(n_rows, float(blk["age_bin_high"]), dtype=float)) cols["horizon_tau"].append( np.full(n_rows, float(blk["horizon_tau"]), dtype=float)) cols["cause_id"].append(np.repeat(cause_id, P)) cols["topk_percent"].append(np.tile(topk_percents.astype(float), K)) cols["n_samples"].append( np.full(n_rows, int(blk["n_samples"]), dtype=int)) n_pos = np.asarray(blk["n_positives"], dtype=int) cols["n_positives"].append(np.repeat(n_pos, P)) auc = np.asarray(blk["auc"], dtype=float) auprc = np.asarray(blk["auprc"], dtype=float) brier = np.asarray(blk["brier_score"], dtype=float) cols["auc"].append(np.repeat(auc, P)) cols["auprc"].append(np.repeat(auprc, P)) cols["brier_score"].append(np.repeat(brier, P)) # precision/recall are stored as (P,K); we want cause-major rows, i.e. # (K,P) then flatten. prec = np.asarray(blk["precision_at_K"], dtype=float) rec = np.asarray(blk["recall_at_K"], dtype=float) if prec.shape != (P, K) or rec.shape != (P, K): raise ValueError( f"Expected precision/recall shapes (P,K)=({P},{K}); got {prec.shape} and {rec.shape}" ) cols["precision_at_K"].append(prec.T.reshape(-1)) cols["recall_at_K"].append(rec.T.reshape(-1)) out = {k: np.concatenate(v, axis=0) for k, v in cols.items()} return pd.DataFrame(out) def aggregate_metrics_columnar( blocks: List[Dict[str, Any]], *, topk_percents: np.ndarray, cause_id: np.ndarray, ) -> pd.DataFrame: """Aggregate per-bin results across age bins (macro/weighted) and MC (mean/std). This is a vectorized, columnar replacement for the old pandas groupby/apply. Semantics match the previous implementation: - bins with n_samples==0 are excluded from bin-aggregation - macro: unweighted mean over bins (NaN-aware) - weighted: weighted mean over bins using weights=n_samples (NaN-aware) - across MC: mean/std (ddof=1), NaN-aware """ if len(blocks) == 0: return pd.DataFrame( columns=[ "agg_type", "horizon_tau", "topk_percent", "cause_id", "n_mc", "n_bins_used_mean", "n_samples_total_mean", "n_positives_total_mean", "auc_mean", "auc_std", "auprc_mean", "auprc_std", "recall_at_K_mean", "recall_at_K_std", "precision_at_K_mean", "precision_at_K_std", "brier_score_mean", "brier_score_std", ] ) P = int(topk_percents.size) cause_id = np.asarray(cause_id, dtype=int) K = int(cause_id.size) # Group blocks by (mc_idx, horizon_tau) keys: List[Tuple[int, float]] = [] grouped: Dict[Tuple[int, float], List[Dict[str, Any]]] = {} for blk in blocks: key = (int(blk["mc_idx"]), float(blk["horizon_tau"])) if key not in grouped: grouped[key] = [] keys.append(key) grouped[key].append(blk) mc_vals = sorted({k[0] for k in keys}) tau_vals = sorted({k[1] for k in keys}) M = len(mc_vals) T = len(tau_vals) mc_index = {mc: i for i, mc in enumerate(mc_vals)} tau_index = {tau: i for i, tau in enumerate(tau_vals)} # Per (agg_type, mc, tau): store arrays # metrics: (M,T,K) and (M,T,P,K) auc_macro = np.full((M, T, K), np.nan, dtype=float) auc_weighted = np.full((M, T, K), np.nan, dtype=float) ap_macro = np.full((M, T, K), np.nan, dtype=float) ap_weighted = np.full((M, T, K), np.nan, dtype=float) brier_macro = np.full((M, T, K), np.nan, dtype=float) brier_weighted = np.full((M, T, K), np.nan, dtype=float) prec_macro = np.full((M, T, P, K), np.nan, dtype=float) prec_weighted = np.full((M, T, P, K), np.nan, dtype=float) rec_macro = np.full((M, T, P, K), np.nan, dtype=float) rec_weighted = np.full((M, T, P, K), np.nan, dtype=float) n_bins_used = np.zeros((M, T), dtype=float) n_samples_total = np.zeros((M, T), dtype=float) n_pos_total = np.zeros((M, T, K), dtype=float) for (mc, tau), blks in grouped.items(): mi = mc_index[mc] ti = tau_index[tau] # keep only bins with n_samples>0 blks_nz = [b for b in blks if int(b["n_samples"]) > 0] if len(blks_nz) == 0: n_bins_used[mi, ti] = 0.0 n_samples_total[mi, ti] = 0.0 n_pos_total[mi, ti, :] = 0.0 continue w = np.asarray([int(b["n_samples"]) for b in blks_nz], dtype=float) # (B,) n_bins_used[mi, ti] = float(len(w)) n_samples_total[mi, ti] = float(w.sum()) npos = np.stack([np.asarray(b["n_positives"], dtype=float) for b in blks_nz], axis=0) # (B,K) n_pos_total[mi, ti, :] = npos.sum(axis=0) auc_b = np.stack([np.asarray(b["auc"], dtype=float) for b in blks_nz], axis=0) # (B,K) ap_b = np.stack([np.asarray(b["auprc"], dtype=float) for b in blks_nz], axis=0) brier_b = np.stack([np.asarray(b["brier_score"], dtype=float) for b in blks_nz], axis=0) auc_macro[mi, ti, :] = _nanmean_np(auc_b, axis=0) ap_macro[mi, ti, :] = _nanmean_np(ap_b, axis=0) brier_macro[mi, ti, :] = _nanmean_np(brier_b, axis=0) auc_weighted[mi, ti, :] = _weighted_mean_np(auc_b, w, axis=0) ap_weighted[mi, ti, :] = _weighted_mean_np(ap_b, w, axis=0) brier_weighted[mi, ti, :] = _weighted_mean_np(brier_b, w, axis=0) prec_b = np.stack([np.asarray(b["precision_at_K"], dtype=float) for b in blks_nz], axis=0) # (B,P,K) rec_b = np.stack([np.asarray(b["recall_at_K"], dtype=float) for b in blks_nz], axis=0) # macro mean over bins prec_macro[mi, ti, :, :] = _nanmean_np(prec_b, axis=0) rec_macro[mi, ti, :, :] = _nanmean_np(rec_b, axis=0) # weighted mean over bins (weights along bin axis) w3 = w.reshape(-1, 1, 1) prec_weighted[mi, ti, :, :] = _weighted_mean_np(prec_b, w3, axis=0) rec_weighted[mi, ti, :, :] = _weighted_mean_np(rec_b, w3, axis=0) # Across-MC aggregation (mean/std), then emit long-format df keyed by # (agg_type, horizon_tau, topk_percent, cause_id) rows: Dict[str, List[np.ndarray]] = { "agg_type": [], "horizon_tau": [], "topk_percent": [], "cause_id": [], "n_mc": [], "n_bins_used_mean": [], "n_samples_total_mean": [], "n_positives_total_mean": [], "auc_mean": [], "auc_std": [], "auprc_mean": [], "auprc_std": [], "recall_at_K_mean": [], "recall_at_K_std": [], "precision_at_K_mean": [], "precision_at_K_std": [], "brier_score_mean": [], "brier_score_std": [], } cause_long = np.repeat(cause_id, P) topk_long = np.tile(topk_percents.astype(float), K) n_mc_val = float(M) for ti, tau in enumerate(tau_vals): # scalar totals (repeat across causes/topk) n_bins_mean = float( np.mean(n_bins_used[:, ti])) if M > 0 else float("nan") n_samp_mean = float( np.mean(n_samples_total[:, ti])) if M > 0 else float("nan") n_pos_mean = _nanmean_np(n_pos_total[:, ti, :], axis=0) # (K,) for agg_type in ("macro", "weighted"): if agg_type == "macro": auc_m = _nanmean_np(auc_macro[:, ti, :], axis=0) auc_s = _nanstd_np_ddof1(auc_macro[:, ti, :], axis=0) ap_m = _nanmean_np(ap_macro[:, ti, :], axis=0) ap_s = _nanstd_np_ddof1(ap_macro[:, ti, :], axis=0) brier_m = _nanmean_np(brier_macro[:, ti, :], axis=0) brier_s = _nanstd_np_ddof1(brier_macro[:, ti, :], axis=0) prec_m = _nanmean_np(prec_macro[:, ti, :, :], axis=0) # (P,K) prec_s = _nanstd_np_ddof1(prec_macro[:, ti, :, :], axis=0) rec_m = _nanmean_np(rec_macro[:, ti, :, :], axis=0) rec_s = _nanstd_np_ddof1(rec_macro[:, ti, :, :], axis=0) else: auc_m = _nanmean_np(auc_weighted[:, ti, :], axis=0) auc_s = _nanstd_np_ddof1(auc_weighted[:, ti, :], axis=0) ap_m = _nanmean_np(ap_weighted[:, ti, :], axis=0) ap_s = _nanstd_np_ddof1(ap_weighted[:, ti, :], axis=0) brier_m = _nanmean_np(brier_weighted[:, ti, :], axis=0) brier_s = _nanstd_np_ddof1(brier_weighted[:, ti, :], axis=0) prec_m = _nanmean_np(prec_weighted[:, ti, :, :], axis=0) prec_s = _nanstd_np_ddof1(prec_weighted[:, ti, :, :], axis=0) rec_m = _nanmean_np(rec_weighted[:, ti, :, :], axis=0) rec_s = _nanstd_np_ddof1(rec_weighted[:, ti, :, :], axis=0) n_rows = K * P rows["agg_type"].append(np.full(n_rows, agg_type, dtype=object)) rows["horizon_tau"].append(np.full(n_rows, float(tau), dtype=float)) rows["topk_percent"].append(topk_long) rows["cause_id"].append(cause_long) rows["n_mc"].append(np.full(n_rows, n_mc_val, dtype=float)) rows["n_bins_used_mean"].append( np.full(n_rows, n_bins_mean, dtype=float)) rows["n_samples_total_mean"].append( np.full(n_rows, n_samp_mean, dtype=float)) rows["n_positives_total_mean"].append(np.repeat(n_pos_mean, P)) rows["auc_mean"].append(np.repeat(auc_m, P)) rows["auc_std"].append(np.repeat(auc_s, P)) rows["auprc_mean"].append(np.repeat(ap_m, P)) rows["auprc_std"].append(np.repeat(ap_s, P)) rows["brier_score_mean"].append(np.repeat(brier_m, P)) rows["brier_score_std"].append(np.repeat(brier_s, P)) rows["precision_at_K_mean"].append(prec_m.T.reshape(-1)) rows["precision_at_K_std"].append(prec_s.T.reshape(-1)) rows["recall_at_K_mean"].append(rec_m.T.reshape(-1)) rows["recall_at_K_std"].append(rec_s.T.reshape(-1)) out = {k: np.concatenate(v, axis=0) for k, v in rows.items()} df = pd.DataFrame(out) return df.sort_values( ["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True ) def aggregate_age_bin_results(df_by_bin: pd.DataFrame) -> pd.DataFrame: """Aggregate per-bin age evaluation results. Produces both: - macro: unweighted mean over bins with n_samples>0 - weighted: weighted mean over bins using weights=n_samples Then aggregates across MC repetitions (mean/std). Requires df_by_bin to include: mc_idx, horizon_tau, topk_percent, cause_id, age_bin_id, n_samples, n_positives, auc, auprc, recall_at_K, precision_at_K, brier_score Returns: DataFrame keyed by (agg_type, horizon_tau, topk_percent, cause_id) """ if df_by_bin is None or len(df_by_bin) == 0: return pd.DataFrame( columns=[ "agg_type", "horizon_tau", "topk_percent", "cause_id", "n_mc", "n_bins_used_mean", "n_samples_total_mean", "n_positives_total_mean", "auc_mean", "auc_std", "auprc_mean", "auprc_std", "recall_at_K_mean", "recall_at_K_std", "precision_at_K_mean", "precision_at_K_std", "brier_score_mean", "brier_score_std", ] ) def _bin_aggregate(group: pd.DataFrame, *, weighted: bool) -> pd.Series: g = group[group["n_samples"] > 0] if len(g) == 0: return pd.Series( dict( n_bins_used=0, n_samples_total=0, n_positives_total=0, auc=float("nan"), auprc=float("nan"), recall_at_K=float("nan"), precision_at_K=float("nan"), brier_score=float("nan"), ) ) n_bins_used = int(g["age_bin_id"].nunique()) n_samples_total = int(g["n_samples"].sum()) n_positives_total = int(g["n_positives"].sum()) if not weighted: return pd.Series( dict( n_bins_used=n_bins_used, n_samples_total=n_samples_total, n_positives_total=n_positives_total, auc=float(g["auc"].mean()), auprc=float(g["auprc"].mean()), recall_at_K=float(g["recall_at_K"].mean()), precision_at_K=float(g["precision_at_K"].mean()), brier_score=float(g["brier_score"].mean()), ) ) w = g["n_samples"].to_numpy(dtype=float) w_sum = float(w.sum()) if w_sum <= 0.0: return pd.Series( dict( n_bins_used=n_bins_used, n_samples_total=n_samples_total, n_positives_total=n_positives_total, auc=float("nan"), auprc=float("nan"), recall_at_K=float("nan"), precision_at_K=float("nan"), brier_score=float("nan"), ) ) def _wavg(col: str) -> float: return float(np.average(g[col].to_numpy(dtype=float), weights=w)) return pd.Series( dict( n_bins_used=n_bins_used, n_samples_total=n_samples_total, n_positives_total=n_positives_total, auc=_wavg("auc"), auprc=_wavg("auprc"), recall_at_K=_wavg("recall_at_K"), precision_at_K=_wavg("precision_at_K"), brier_score=_wavg("brier_score"), ) ) # Kept for backward compatibility (e.g., if callers load a CSV and need to # aggregate). Prefer `aggregate_metrics_columnar` during evaluation. group_keys = ["mc_idx", "horizon_tau", "topk_percent", "cause_id"] df = df_by_bin[df_by_bin["n_samples"] > 0].copy() if len(df) == 0: return pd.DataFrame( columns=[ "agg_type", "horizon_tau", "topk_percent", "cause_id", "n_mc", "n_bins_used_mean", "n_samples_total_mean", "n_positives_total_mean", "auc_mean", "auc_std", "auprc_mean", "auprc_std", "recall_at_K_mean", "recall_at_K_std", "precision_at_K_mean", "precision_at_K_std", "brier_score_mean", "brier_score_std", ] ) # Macro: mean over bins df_mc_macro = ( df.groupby(group_keys, as_index=False) .agg( n_bins_used=("age_bin_id", "nunique"), n_samples_total=("n_samples", "sum"), n_positives_total=("n_positives", "sum"), auc=("auc", "mean"), auprc=("auprc", "mean"), recall_at_K=("recall_at_K", "mean"), precision_at_K=("precision_at_K", "mean"), brier_score=("brier_score", "mean"), ) ) df_mc_macro["agg_type"] = "macro" # Weighted: weighted mean over bins with weights=n_samples, NaN-aware per metric w = df["n_samples"].astype(float) df_w = df.copy() for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]: m = df_w[col].astype(float) ww = w.where(m.notna(), other=0.0) df_w[f"__num_{col}"] = (m.fillna(0.0) * w) df_w[f"__den_{col}"] = ww df_mc_w = df_w.groupby(group_keys, as_index=False).agg( n_bins_used=("age_bin_id", "nunique"), n_samples_total=("n_samples", "sum"), n_positives_total=("n_positives", "sum"), **{f"__num_{c}": (f"__num_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]}, **{f"__den_{c}": (f"__den_{c}", "sum") for c in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]}, ) for col in ["auc", "auprc", "recall_at_K", "precision_at_K", "brier_score"]: num = df_mc_w[f"__num_{col}"].astype(float) den = df_mc_w[f"__den_{col}"].astype(float) df_mc_w[col] = (num / den).where(den > 0.0, other=float("nan")) df_mc_w.drop(columns=[f"__num_{col}", f"__den_{col}"], inplace=True) df_mc_w["agg_type"] = "weighted" df_mc_binagg = pd.concat([df_mc_macro, df_mc_w], ignore_index=True) df_agg = ( df_mc_binagg.groupby( ["agg_type", "horizon_tau", "topk_percent", "cause_id"], as_index=False) .agg( n_mc=("mc_idx", "nunique"), n_bins_used_mean=("n_bins_used", "mean"), n_samples_total_mean=("n_samples_total", "mean"), n_positives_total_mean=("n_positives_total", "mean"), auc_mean=("auc", "mean"), auc_std=("auc", "std"), auprc_mean=("auprc", "mean"), auprc_std=("auprc", "std"), recall_at_K_mean=("recall_at_K", "mean"), recall_at_K_std=("recall_at_K", "std"), precision_at_K_mean=("precision_at_K", "mean"), precision_at_K_std=("precision_at_K", "std"), brier_score_mean=("brier_score", "mean"), brier_score_std=("brier_score", "std"), ) .sort_values(["agg_type", "horizon_tau", "topk_percent", "cause_id"], ignore_index=True) ) return df_agg # NOTE: metric computation is torch/GPU-native in `torch_metrics.py`. # NumPy/Pandas are only used for final CSV formatting/aggregation. @dataclass class EvalAgeConfig: horizons_years: Sequence[float] age_bins: Sequence[Tuple[float, float]] topk_percents: Sequence[float] = (1.0, 5.0, 10.0, 20.0, 50.0) n_mc: int = 5 seed: int = 0 cause_ids: Optional[Sequence[int]] = None store_per_cause: bool = True @torch.inference_mode() def evaluate_time_dependent_age_bins( model: torch.nn.Module, head: torch.nn.Module, criterion, dataloader: torch.utils.data.DataLoader, n_disease: int, cfg: EvalAgeConfig, device: str | torch.device, mc_offset: int = 0, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Delphi-2M-style age-bin evaluation with strict horizon alignment. Semantics (strict): for each (MC, horizon tau, age bin) we independently: - build the eligible token set within that bin - enforce follow-up coverage: t_ctx + tau <= t_end - randomly sample exactly one token per individual within the bin (de-dup) - recompute context representations and predictions for that (tau, bin) Returns: df_by_bin: rows keyed by (mc_idx, age_bin_id, horizon_tau, topk_percent, cause_id) df_agg: aggregated metrics across age bins and MC with agg_type in {macro, weighted} """ device = torch.device(device) model.eval() head.eval() horizons_years = [float(x) for x in cfg.horizons_years] if len(horizons_years) == 0: raise ValueError("cfg.horizons_years must be non-empty") age_bins = [(float(a), float(b)) for (a, b) in cfg.age_bins] if len(age_bins) == 0: raise ValueError("cfg.age_bins must be non-empty") for (a, b) in age_bins: if not (b > a): raise ValueError( f"age_bins must be (low, high) with high>low; got {(a, b)}") topk_percents = [float(x) for x in cfg.topk_percents] if len(topk_percents) == 0: raise ValueError("cfg.topk_percents must be non-empty") if any((p <= 0.0 or p > 100.0) for p in topk_percents): raise ValueError( f"All topk_percents must be in (0,100]; got {topk_percents}") if int(cfg.n_mc) <= 0: raise ValueError("cfg.n_mc must be >= 1") if cfg.cause_ids is None: cause_ids = None n_causes_eval = int(n_disease) cause_id_vec = np.arange(n_causes_eval, dtype=int) else: cause_ids = torch.tensor( list(cfg.cause_ids), dtype=torch.long, device=device) n_causes_eval = int(cause_ids.numel()) cause_id_vec = np.asarray(list(cfg.cause_ids), dtype=int) topk_percents_np = np.asarray(topk_percents, dtype=float) # Columnar per-(mc,tau,bin) blocks; avoids Python per-row dict appends. blocks: List[Dict[str, Any]] = [] for mc_idx in range(int(cfg.n_mc)): global_mc_idx = int(mc_offset) + int(mc_idx) # Storage for this MC only: (tau, bin) -> list of GPU tensors. # This keeps computations GPU-first while preventing a factor-n_mc # blow-up in GPU memory. y_true_mc: List[List[List[torch.Tensor]]] = [ [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) ] y_pred_mc: List[List[List[torch.Tensor]]] = [ [[] for _ in range(len(age_bins))] for _ in range(len(horizons_years)) ] # tqdm over batches; include MC idx in description. for batch_idx, batch in enumerate( tqdm(dataloader, desc=f"Evaluating (MC {mc_idx+1}/{cfg.n_mc})", unit="batch") ): event_seq, time_seq, cont_feats, cate_feats, sexes = batch event_seq = event_seq.to(device) time_seq = time_seq.to(device) cont_feats = cont_feats.to(device) cate_feats = cate_feats.to(device) sexes = sexes.to(device) B = int(event_seq.size(0)) b = torch.arange(B, device=device) # Hoist backbone forward pass: inputs are identical across (tau, age_bin) # within this batch, so this is safe and numerically identical. h = model(event_seq, time_seq, sexes, cont_feats, cate_feats) # (B,L,D) for tau_idx, tau_y in enumerate(horizons_years): tau_tensor = torch.tensor(float(tau_y), device=device) for bin_idx, (a_lo, a_hi) in enumerate(age_bins): # Diversify RNG stream across MC/tau/bin/batch to reduce correlation. seed = ( int(cfg.seed) + (100_000 * int(global_mc_idx)) + (1_000 * int(tau_idx)) + (10 * int(bin_idx)) + int(batch_idx) ) keep, t_ctx = sample_context_in_fixed_age_bin( event_seq=event_seq, time_seq=time_seq, tau_years=float(tau_y), age_bin=(float(a_lo), float(a_hi)), seed=seed, ) if not keep.any(): continue # Bin-specific prediction: context indices differ per (tau, bin) # but the backbone features do not. c = h[b, t_ctx] logits = head(c) cifs = criterion.calculate_cifs( logits, taus=tau_tensor ) if cifs.ndim != 2: raise ValueError( "criterion.calculate_cifs must return (B,K) for scalar tau; " f"got shape={tuple(cifs.shape)}" ) if cause_ids is None: y = multi_hot_ever_within_horizon( event_seq=event_seq, time_seq=time_seq, t_ctx=t_ctx, tau_years=float(tau_y), n_disease=n_disease, ) preds = cifs else: y = multi_hot_selected_causes_within_horizon( event_seq=event_seq, time_seq=time_seq, t_ctx=t_ctx, tau_years=float(tau_y), cause_ids=cause_ids, n_disease=n_disease, ) preds = cifs.index_select(dim=1, index=cause_ids) y_true_mc[tau_idx][bin_idx].append( y[keep].detach().to(dtype=torch.bool) ) y_pred_mc[tau_idx][bin_idx].append( preds[keep].detach().to(dtype=torch.float32) ) # Aggregate this MC immediately (frees GPU memory before next MC). for h_idx, tau_y in enumerate(horizons_years): for bin_idx, (a_lo, a_hi) in enumerate(age_bins): if len(y_true_mc[h_idx][bin_idx]) == 0: # No samples in this bin for this (mc, tau): store a single # block with NaN metric vectors. K = int(n_causes_eval) P = int(topk_percents_np.size) blocks.append( dict( mc_idx=global_mc_idx, age_bin_id=bin_idx, age_bin_low=float(a_lo), age_bin_high=float(a_hi), horizon_tau=float(tau_y), n_samples=0, cause_id=cause_id_vec, n_positives=np.zeros((K,), dtype=int), auc=np.full((K,), np.nan, dtype=float), auprc=np.full((K,), np.nan, dtype=float), brier_score=np.full((K,), np.nan, dtype=float), precision_at_K=np.full((P, K), np.nan, dtype=float), recall_at_K=np.full((P, K), np.nan, dtype=float), ) ) continue yb_t = torch.cat(y_true_mc[h_idx][bin_idx], dim=0) pb_t = torch.cat(y_pred_mc[h_idx][bin_idx], dim=0) if tuple(yb_t.shape) != tuple(pb_t.shape): raise ValueError( f"Shape mismatch mc={mc_idx} tau={tau_y} bin={bin_idx}: y{tuple(yb_t.shape)} vs p{tuple(pb_t.shape)}" ) n_samples = int(yb_t.size(0)) metrics = compute_binary_metrics_torch( y_true=yb_t, y_pred=pb_t, k_percents=topk_percents, tie_mode="exact", chunk_size=128, compute_ici=False, ) # Collect a single columnar block (vectors, not per-row dicts). blocks.append( dict( mc_idx=global_mc_idx, age_bin_id=bin_idx, age_bin_low=float(a_lo), age_bin_high=float(a_hi), horizon_tau=float(tau_y), n_samples=int(n_samples), cause_id=cause_id_vec, n_positives=metrics.n_pos_per_cause.detach().cpu().numpy().astype(int), auc=metrics.auc_per_cause.detach().cpu().numpy().astype(float), auprc=metrics.ap_per_cause.detach().cpu().numpy().astype(float), brier_score=metrics.brier_per_cause.detach().cpu().numpy().astype(float), precision_at_K=metrics.precision_at_k.detach().cpu().numpy().astype(float), recall_at_K=metrics.recall_at_k.detach().cpu().numpy().astype(float), ) ) # Aggregation is computed from columnar blocks (fast, no pandas apply). df_agg = aggregate_metrics_columnar( blocks, topk_percents=topk_percents_np, cause_id=cause_id_vec, ) if bool(cfg.store_per_cause): df_by_bin = _blocks_to_df_by_bin(blocks, topk_percents=topk_percents_np) else: df_by_bin = pd.DataFrame( columns=[ "mc_idx", "age_bin_id", "age_bin_low", "age_bin_high", "horizon_tau", "topk_percent", "cause_id", "n_samples", "n_positives", "auc", "auprc", "recall_at_K", "precision_at_K", "brier_score", ] ) return df_by_bin, df_agg