diff --git a/evaluate_models.py b/evaluate_models.py index 83b783c..a1165df 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -4,7 +4,6 @@ import json import math import os import random -import statistics import sys import time from concurrent.futures import ThreadPoolExecutor, as_completed @@ -522,7 +521,7 @@ def check_cif_integrity( # Metrics # ============================================================ -# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) --- +# --- Rank-based ROC AUC (ties handled via midranks) --- def compute_midrank(x: np.ndarray) -> np.ndarray: """Vectorized midrank computation (ties -> average ranks).""" @@ -554,75 +553,6 @@ def compute_midrank(x: np.ndarray) -> np.ndarray: return out -def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]: - """Fast DeLong method for computing AUC covariance. - - predictions_sorted_transposed: shape (n_classifiers, n_examples) with positive examples first. - """ - preds = np.asarray(predictions_sorted_transposed, dtype=float) - m = int(label_1_count) - n = int(preds.shape[1] - m) - if m <= 0 or n <= 0: - return np.array([float("nan")]), np.array([[float("nan")]]) - - pos = preds[:, :m] - neg = preds[:, m:] - - tx = np.array([compute_midrank(x) for x in pos]) - ty = np.array([compute_midrank(x) for x in neg]) - tz = np.array([compute_midrank(x) for x in preds]) - - aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n) - - v01 = (tz[:, :m] - tx) / n - v10 = 1.0 - (tz[:, m:] - ty) / m - - if v01.shape[0] > 1: - sx = np.cov(v01) - sy = np.cov(v10) - else: - # Single-classifier case: compute row-wise variance (do not flatten). - var_v01 = float(np.var(v01, axis=1, ddof=1)[0]) - var_v10 = float(np.var(v10, axis=1, ddof=1)[0]) - sx = np.array([[var_v01]]) - sy = np.array([[var_v10]]) - delong_cov = sx / m + sy / n - return aucs, delong_cov - - -def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]: - y = np.asarray(ground_truth, dtype=int) - p = np.asarray(predictions, dtype=float) - if y.ndim != 1 or p.ndim != 1 or y.shape[0] != p.shape[0]: - raise ValueError("calc_auc_variance expects 1D arrays of equal length") - - m = int(np.sum(y == 1)) - n = int(np.sum(y == 0)) - if m == 0 or n == 0: - return float("nan"), float("nan") - - order = np.argsort(-y) # positives first - preds_sorted = p[order] - aucs, cov = fastDeLong(preds_sorted[np.newaxis, :], m) - auc = float(aucs[0]) - var = float(cov[0, 0]) - return auc, var - - -def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]: - """Return (auc, ci_low, ci_high) using DeLong variance and normal CI.""" - auc, var = calc_auc_variance(ground_truth, predictions) - if not np.isfinite(var) or var <= 0: - print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN") - return float(auc), float("nan"), float("nan") - - sd = math.sqrt(var) - z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0) - lo = max(0.0, auc - z * sd) - hi = min(1.0, auc + z * sd) - return float(auc), float(lo), float(hi) - - def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float: """Rank-based ROC AUC via Mann–Whitney U statistic (ties handled by midranks). @@ -643,49 +573,6 @@ def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float: return float(auc) -def bootstrap_auc_ci( - scores: np.ndarray, - labels: np.ndarray, - n_bootstrap: int, - alpha: float = 0.95, - seed: int = 0, -) -> Tuple[float, float, float]: - """Bootstrap CI for ROC AUC (percentile).""" - rng = np.random.default_rng(int(seed)) - scores = np.asarray(scores, dtype=float) - labels = np.asarray(labels, dtype=int) - n = labels.shape[0] - if n == 0 or np.all(labels == labels[0]): - print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN") - return float("nan"), float("nan"), float("nan") - - auc_full = roc_auc_rank(labels, scores) - if not np.isfinite(auc_full): - print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN") - return float("nan"), float("nan"), float("nan") - - aucs: List[float] = [] - for _ in range(int(n_bootstrap)): - idx = rng.integers(0, n, size=n) - yb = labels[idx] - if np.all(yb == yb[0]): - continue - pb = scores[idx] - auc = roc_auc_rank(yb, pb) - if np.isfinite(auc): - aucs.append(float(auc)) - - if len(aucs) < 10: - print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN") - return float(auc_full), float("nan"), float("nan") - - lo_q = (1.0 - float(alpha)) / 2.0 - hi_q = 1.0 - lo_q - lo = float(np.quantile(aucs, lo_q)) - hi = float(np.quantile(aucs, hi_q)) - return float(auc_full), lo, hi - - def brier_score(p: np.ndarray, y: np.ndarray) -> float: p = np.asarray(p, dtype=float) y = np.asarray(y, dtype=float) @@ -1040,23 +927,75 @@ def compute_capture_points( return rows -def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]: - """Bucketize horizons into short/medium/long using the continuous-horizon rule.""" - uniq = sorted({float(h) for h in horizons}) - mapping: Dict[float, str] = {} - rows: List[Dict[str, Any]] = [] - # First 4 short, next 4 medium, rest long. - for i, h in enumerate(uniq): - if i < 4: - g, gr = "short", 1 - elif i < 8: - g, gr = "medium", 2 - else: - g, gr = "long", 3 - mapping[float(h)] = g - rows.append({"horizon": float(h), "group": g, "group_rank": int(gr)}) - method = "continuous_unique_horizons_first4_next4_rest" - return rows, mapping, method +def compute_event_rate_at_topk_causes( + p_tau: np.ndarray, + y_tau: np.ndarray, + topk_list: Sequence[int], +) -> List[Dict[str, Any]]: + """Compute Event Rate@K for cross-cause prioritization. + + For each individual, rank causes by predicted risk p_tau at a fixed horizon. + For each K, select top-K causes and compute the fraction that occur within the horizon. + + Args: + p_tau: (N, K) predicted CIFs at a fixed horizon + y_tau: (N, K) binary labels (0/1) whether cause occurs within the horizon + topk_list: list of K values to evaluate + + Returns: + List of rows with: topk, mean, median, n_total. + """ + p = np.asarray(p_tau, dtype=float) + y = np.asarray(y_tau, dtype=float) + if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape: + raise ValueError( + "compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape") + + n, k_total = p.shape + if n == 0 or k_total == 0: + out: List[Dict[str, Any]] = [] + for kk in topk_list: + out.append( + { + "topk": int(max(1, int(kk))), + "event_rate_mean": float("nan"), + "event_rate_median": float("nan"), + "n_total": int(n), + } + ) + return out + + # Sanitize K list. + topks = sorted({int(x) for x in topk_list if int(x) > 0}) + if not topks: + return [] + + max_k = min(int(max(topks)), int(k_total)) + if max_k <= 0: + return [] + + # Efficient: get top max_k causes per individual, then sort within those. + part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k] # (N, max_k) + p_part = np.take_along_axis(p, part, axis=1) + order = np.argsort(-p_part, axis=1) + top_sorted = np.take_along_axis(part, order, axis=1) # (N, max_k) + + out_rows: List[Dict[str, Any]] = [] + for kk in topks: + kk_eff = min(int(kk), int(k_total)) + idx = top_sorted[:, :kk_eff] + y_sel = np.take_along_axis(y, idx, axis=1) + # fraction of selected causes that occur + per_person = np.mean(y_sel, axis=1) + out_rows.append( + { + "topk": int(kk_eff), + "event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"), + "event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"), + "n_total": int(n), + } + ) + return out_rows def count_occurs_within_horizon( @@ -1293,8 +1232,6 @@ def evaluate_one_model( out_rows: List[Dict[str, Any]], calib_rows: List[Dict[str, Any]], calib_cause_ids: Optional[Sequence[int]], - auc_ci_method: str, - bootstrap_n: int, n_calib_bins: int = 10, metric_workers: int = 0, progress: str = "auto", @@ -1375,15 +1312,8 @@ def evaluate_one_model( } ) - # Secondary: discrimination via AUC at the same horizon. - if auc_ci_method == "none": - auc, lo, hi = float("nan"), float("nan"), float("nan") - elif auc_ci_method == "bootstrap": - auc, lo, hi = bootstrap_auc_ci( - p, y, n_bootstrap=bootstrap_n, alpha=0.95 - ) - else: - auc, lo, hi = delong_ci(y, p, alpha=0.95) + # Secondary: discrimination via AUC at the same horizon (point estimate only). + auc = roc_auc_rank(y, p) local_rows.append( { @@ -1392,8 +1322,8 @@ def evaluate_one_model( "horizon": float(tau), "cause": int(cid), "value": float(auc), - "ci_low": lo, - "ci_high": hi, + "ci_low": "", + "ci_high": "", } ) @@ -1592,15 +1522,6 @@ def main() -> int: ap.add_argument("--integrity_strict", action="store_true", default=False) ap.add_argument("--integrity_tol", type=float, default=1e-6) - # AUC CI methods - ap.add_argument( - "--auc_ci_method", - type=str, - default="delong", - choices=["delong", "bootstrap", "none"], - ) - ap.add_argument("--bootstrap_n", type=int, default=2000) - # Speed/UX ap.add_argument( "--metric_workers", @@ -1630,6 +1551,15 @@ def main() -> int: default=50, help="If >0, also export a dense capture curve for k=1..max_pct", ) + + # High-risk cause concentration (cross-cause prioritization) + ap.add_argument( + "--cause_concentration_topk", + type=int, + nargs="*", + default=[5, 10, 20, 50], + help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)", + ) args = ap.parse_args() set_deterministic(args.seed) @@ -1708,25 +1638,13 @@ def main() -> int: } ) - # Horizon groups for Experiment 3 - hg_rows, horizon_to_group, hg_method = make_horizon_groups( - args.eval_horizons) - write_simple_csv( - os.path.join(export_dir, "horizon_groups.csv"), - ["horizon", "group", "group_rank"], - hg_rows, - ) - summary_rows: List[Dict[str, Any]] = [] calib_rows: List[Dict[str, Any]] = [] # Experiment exports (accumulated across models) - rs_bins_rows: List[Dict[str, Any]] = [] - rs_sum_rows: List[Dict[str, Any]] = [] cap_points_rows: List[Dict[str, Any]] = [] cap_curve_rows: List[Dict[str, Any]] = [] - cal_group_sum_rows: List[Dict[str, Any]] = [] - cal_group_bins_rows: List[Dict[str, Any]] = [] + conc_rows: List[Dict[str, Any]] = [] # Track per-model integrity status for meta JSON. integrity_meta: Dict[str, Any] = {} @@ -1817,8 +1735,6 @@ def main() -> int: out_rows=model_rows, calib_rows=calib_rows, calib_cause_ids=top_cause_ids.tolist(), - auc_ci_method=str(args.auc_ci_method), - bootstrap_n=int(args.bootstrap_n), metric_workers=int(args.metric_workers), progress=str(args.progress), ) @@ -1831,61 +1747,37 @@ def main() -> int: ) summary_rows.extend(model_summary_rows) - # Convenience slices for user-facing experiments (focus causes only). - cause_cif_focus = cif_full[:, top_cause_ids, :] - y_within_focus = y_cause_within_tau[:, top_cause_ids, :] - # ============================================================ - # Experiment 1: Risk stratification bins + summary + # Experiment: High-Risk Cause Concentration at fixed horizon + # (cross-cause prioritization accuracy) # ============================================================ + topk_causes = [int(x) for x in args.cause_concentration_topk] for sex_label, sex_mask in _sex_slices(sex if sex.size else None): for h_i, tau in enumerate(args.eval_horizons): - for j, cause_id in enumerate(top_cause_ids.tolist()): - p = cause_cif_focus[:, j, h_i] - y = y_within_focus[:, j, h_i] - if sex_mask is not None: - p = p[sex_mask] - y = y[sex_mask] - q_used, bin_rows, summary = compute_risk_stratification_bins( - p, y, q_default=10) - for br in bin_rows: - rs_bins_rows.append( - { - "model_id": model_id, - "model_type": model_type, - "loss_type": loss_type_id, - "age_encoder": age_encoder, - "cov_type": cov_type, - "cause": int(cause_id), - "horizon": float(tau), - "sex": sex_label, - "q": int(br["q"]), - "n_bin": int(br["n_bin"]), - "p_mean": _safe_float(br["p_mean"]), - "y_rate": _safe_float(br["y_rate"]), - "y_overall": _safe_float(br["y_overall"]), - "lift_vs_overall": _safe_float(br["lift_vs_overall"]), - "q_total": int(q_used), - } - ) - rs_sum_rows.append( + p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float) + y_tau_all = np.asarray( + y_cause_within_tau[:, :, h_i], dtype=float) + if sex_mask is not None: + p_tau_all = p_tau_all[sex_mask] + y_tau_all = y_tau_all[sex_mask] + for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes): + conc_rows.append( { "model_id": model_id, "model_type": model_type, "loss_type": loss_type_id, "age_encoder": age_encoder, "cov_type": cov_type, - "cause": int(cause_id), "horizon": float(tau), "sex": sex_label, - "q_total": int(q_used), - "top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]), - "bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]), - "lift_top10_vs_bottom50": _safe_float(summary["lift_top10_vs_bottom50"]), - "slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]), + **rr, } ) + # Convenience slices for user-facing experiments (focus causes only). + cause_cif_focus = cif_full[:, top_cause_ids, :] + y_within_focus = y_cause_within_tau[:, top_cause_ids, :] + # ============================================================ # Experiment 2: High-risk capture points (+ optional curve) # ============================================================ @@ -1932,129 +1824,6 @@ def main() -> int: } ) - # ============================================================ - # Experiment 3: Short/Medium/Long horizon-group calibration - # ============================================================ - # Per-horizon metrics for grouping - # Build a dict for quick access: (cause_id, horizon) -> (brier, ici) - per_h: Dict[Tuple[int, float], Dict[str, float]] = {} - for rr in model_rows: - if rr.get("metric_name") not in {"cause_brier", "cause_ici"}: - continue - try: - cid = int(rr.get("cause")) - except Exception: - continue - if cid not in set(int(x) for x in top_cause_ids.tolist()): - continue - h = _safe_float(rr.get("horizon")) - if not np.isfinite(h): - continue - key = (cid, float(h)) - d = per_h.get(key, {}) - d[str(rr.get("metric_name"))] = _safe_float(rr.get("value")) - per_h[key] = d - - # Compute group summaries and pooled bins using the same quantile bins as exp1 (per slice). - for sex_label, sex_mask in _sex_slices(sex if sex.size else None): - for j, cause_id in enumerate(top_cause_ids.tolist()): - # Decide Q per slice for pooled reliability curve - n_slice = int(np.sum(sex_mask)) if sex_mask is not None else int( - sex.shape[0]) - q_pool = 10 if n_slice >= 200 else 5 - - # Collect per-horizon brier/ici values - group_vals: Dict[str, Dict[str, List[float]]] = {"short": {"brier": [], "ici": [ - ]}, "medium": {"brier": [], "ici": []}, "long": {"brier": [], "ici": []}} - group_n_total: Dict[str, int] = { - "short": 0, "medium": 0, "long": 0} - - # Pooled bins: group -> q -> accumulators - pooled: Dict[str, Dict[int, Dict[str, float]]] = { - "short": {}, "medium": {}, "long": {}} - - for h_i, tau in enumerate(args.eval_horizons): - g = horizon_to_group.get(float(tau), "long") - - # brier/ici per horizon (already computed at full-sample level) - d = per_h.get((int(cause_id), float(tau)), {}) - brier_h = _safe_float(d.get("cause_brier")) - ici_h = _safe_float(d.get("cause_ici")) - if np.isfinite(brier_h): - group_vals[g]["brier"].append(brier_h) - if np.isfinite(ici_h): - group_vals[g]["ici"].append(ici_h) - - # pooled reliability bins from raw p/y - p = cause_cif_focus[:, j, h_i] - y = y_within_focus[:, j, h_i] - if sex_mask is not None: - p = p[sex_mask] - y = y[sex_mask] - if p.size == 0: - continue - edges = _quantile_edges(p, q_pool) - for qi in range(q_pool): - m = (p > edges[qi]) & (p <= edges[qi + 1]) - nb = int(np.sum(m)) - if nb == 0: - continue - pm = float(np.mean(p[m])) - yr = float(np.mean(y[m])) - acc = pooled[g].get( - qi + 1, {"n": 0.0, "p_sum": 0.0, "y_sum": 0.0}) - acc["n"] += float(nb) - acc["p_sum"] += float(nb) * pm - acc["y_sum"] += float(nb) * yr - pooled[g][qi + 1] = acc - group_n_total[g] = max(group_n_total[g], int(p.size)) - - for g in ["short", "medium", "long"]: - bvals = group_vals[g]["brier"] - ivals = group_vals[g]["ici"] - cal_group_sum_rows.append( - { - "model_id": model_id, - "model_type": model_type, - "loss_type": loss_type_id, - "age_encoder": age_encoder, - "cov_type": cov_type, - "cause": int(cause_id), - "sex": sex_label, - "horizon_group": g, - "brier_mean": float(np.mean(bvals)) if bvals else float("nan"), - "brier_median": float(np.median(bvals)) if bvals else float("nan"), - "ici_mean": float(np.mean(ivals)) if ivals else float("nan"), - "ici_median": float(np.median(ivals)) if ivals else float("nan"), - "n_total": int(group_n_total[g]), - "horizon_grouping_method": hg_method, - } - ) - - for qi in range(1, q_pool + 1): - acc = pooled[g].get(qi) - if not acc or float(acc.get("n", 0.0)) <= 0: - continue - n_bin = float(acc["n"]) - cal_group_bins_rows.append( - { - "model_id": model_id, - "model_type": model_type, - "loss_type": loss_type_id, - "age_encoder": age_encoder, - "cov_type": cov_type, - "cause": int(cause_id), - "sex": sex_label, - "horizon_group": g, - "q": int(qi), - "n_bin": int(n_bin), - "p_mean": float(acc["p_sum"] / n_bin), - "y_rate": float(acc["y_sum"] / n_bin), - "q_total": int(q_pool), - "horizon_grouping_method": hg_method, - } - ) - # Optionally write top-cause counts into the main results CSV as metric rows. for tc in top_causes_meta: model_rows.append( @@ -2141,46 +1910,6 @@ def main() -> int: write_calibration_bins_csv(calib_csv_path, calib_rows) # Write experiment exports - write_simple_csv( - os.path.join(export_dir, "risk_stratification_bins.csv"), - [ - "model_id", - "model_type", - "loss_type", - "age_encoder", - "cov_type", - "cause", - "horizon", - "sex", - "q", - "n_bin", - "p_mean", - "y_rate", - "y_overall", - "lift_vs_overall", - "q_total", - ], - rs_bins_rows, - ) - write_simple_csv( - os.path.join(export_dir, "risk_stratification_summary.csv"), - [ - "model_id", - "model_type", - "loss_type", - "age_encoder", - "cov_type", - "cause", - "horizon", - "sex", - "q_total", - "top_decile_y_rate", - "bottom_half_y_rate", - "lift_top10_vs_bottom50", - "slope_pred_vs_obs", - ], - rs_sum_rows, - ) write_simple_csv( os.path.join(export_dir, "lift_capture_points.csv"), [ @@ -2222,45 +1951,23 @@ def main() -> int: ], cap_curve_rows, ) + write_simple_csv( - os.path.join(export_dir, "calibration_groups_summary.csv"), + os.path.join(export_dir, "high_risk_cause_concentration.csv"), [ "model_id", "model_type", "loss_type", "age_encoder", "cov_type", - "cause", + "horizon", "sex", - "horizon_group", - "brier_mean", - "brier_median", - "ici_mean", - "ici_median", + "topk", + "event_rate_mean", + "event_rate_median", "n_total", - "horizon_grouping_method", ], - cal_group_sum_rows, - ) - write_simple_csv( - os.path.join(export_dir, "calibration_groups_bins.csv"), - [ - "model_id", - "model_type", - "loss_type", - "age_encoder", - "cov_type", - "cause", - "sex", - "horizon_group", - "q", - "n_bin", - "p_mean", - "y_rate", - "q_total", - "horizon_grouping_method", - ], - cal_group_bins_rows, + conc_rows, ) # Manifest markdown (stable, user-facing) @@ -2272,13 +1979,9 @@ def main() -> int: "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n" "## Files\n\n" "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n" - "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n" - "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n" - "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n" "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n" "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n" - "- calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket.\n" - "- calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). Intended plot: 3-panel reliability curves per model.\n" + "- high_risk_cause_concentration.csv: Event Rate@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart of Event Rate@K vs K.\n" ) meta = { @@ -2293,12 +1996,11 @@ def main() -> int: "notes": { "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])", "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)", - "secondary_metrics": "cause_auc (discrimination) with optional CI", + "secondary_metrics": "cause_auc (discrimination)", "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported", "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.", "exports_dir": export_dir, "focus_causes": focus_causes, - "horizon_grouping_method": hg_method, }, } with open(args.out_meta_json, "w") as f: