From fcd948818c4f2b402385687d8002b075ae61cf3c Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sat, 17 Jan 2026 15:31:12 +0800 Subject: [PATCH] Age-stratify horizon and next-event evaluation; add diagnostic-only AUC outputs --- evaluate_horizon.py | 372 ++++++++++++++++++++++++++++------------- evaluate_next_event.py | 159 ++++++++++-------- 2 files changed, 348 insertions(+), 183 deletions(-) diff --git a/evaluate_horizon.py b/evaluate_horizon.py index ff4392d..9c9257c 100644 --- a/evaluate_horizon.py +++ b/evaluate_horizon.py @@ -1,18 +1,25 @@ -"""Horizon-capture evaluation. +"""Horizon-capture evaluation (event-driven, age-stratified). -DISCLAIMERS (important): -- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$. - Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW. - Use it for model comparison and diagnostics, not strict statistical interpretation. +This script implements the protocol described in 评估方案.md: -- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment). - Use them to detect probability-mass compression / numerical stability issues; do not claim - calibrated absolute risk. +- Age-stratified evaluation: metrics are computed independently within each age bin (no mixing). +- Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and + there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin + disease events with t0 >= DOA. +- No follow-up completeness filtering (no t0+tau <= t_end constraint). 
+ +Primary outputs per age bin: +- Top-K Event Capture@tau (event-count based) +- Workload–Yield curves (Top-p% people by a person-level horizon score) + +Secondary (diagnostic-only) outputs per age bin: +- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment) """ import argparse +import math import os -from typing import Dict, List, Sequence +from typing import Dict, List, Sequence, Tuple import numpy as np import pandas as pd @@ -40,6 +47,7 @@ from utils import ( roc_auc_ovr, seed_everything, topk_indices, + DAYS_PER_YEAR, ) @@ -58,8 +66,8 @@ def parse_args() -> argparse.Namespace: "--age_bins", type=str, nargs="+", - default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"], - help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)", + default=["40", "45", "50", "55", "60", "65", "70", "inf"], + help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)", ) p.add_argument( @@ -75,7 +83,14 @@ def parse_args() -> argparse.Namespace: "--topk_list", type=int, nargs="+", - default=[1, 5, 10, 20, 50], + default=[5, 10, 20, 50], + ) + p.add_argument( + "--workload_fracs", + type=float, + nargs="+", + default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5], + help="Fractions for workload–yield curves (Top-p%% people).", ) p.add_argument( "--no_tqdm", @@ -85,6 +100,33 @@ def parse_args() -> argparse.Namespace: return p.parse_args() +def _format_age_bin_label(lo: float, hi: float) -> str: + if np.isinf(hi): + return f"[{lo}, inf)" + return f"[{lo}, {hi})" + + +def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray: + age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64) + age_bins_days = age_bins_years * DAYS_PER_YEAR + return np.searchsorted(age_bins_days, t0_days, side="left") - 1 + + +def _event_counts_within_tau( + n_records: int, + event_record_idx: np.ndarray, + event_dt_years: np.ndarray, + tau_years: float, +) -> np.ndarray: + """Count events within (t0, t0+tau] per record 
(event-count, not unique causes).""" + if event_record_idx.size == 0: + return np.zeros((n_records,), dtype=np.int64) + m = event_dt_years <= float(tau_years) + if not np.any(m): + return np.zeros((n_records,), dtype=np.int64) + return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64) + + def build_labels_within_tau_flat( n_records: int, n_causes: int, @@ -148,8 +190,8 @@ def main() -> None: ) # Print disclaimers every run (requested) - print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).") - print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).") + print("PRIMARY METRICS: event-count Capture@K and Workload–Yield, computed independently per age bin.") + print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).") scores = predict_cifs( model, @@ -171,131 +213,235 @@ def main() -> None: raise ValueError("Record count mismatch") # Pre-flatten all future events once to avoid repeated per-record scans. + # NOTE: these are event-level arrays (not unique causes), suitable for event-count Capture@K. evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K) - per_tau_rows: List[Dict[str, object]] = [] - per_cause_rows: List[Dict[str, object]] = [] + # Assign each record to an age bin (based on t0; by construction t0 is within the bin). + t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64) + bin_idx = _assign_age_bin_idx(t0_days, age_bins_years) + age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64) + + capture_rows: List[Dict[str, object]] = [] workload_rows: List[Dict[str, object]] = [] - horizon_iter = enumerate(horizons) + # Diagnostics (optional): approximate event-driven AUC/Brier computed per bin. 
+ diag_rows: List[Dict[str, object]] = [] + diag_per_cause_parts: List[pd.DataFrame] = [] + + bins_iter = range(len(age_bins_years_arr) - 1) if show_progress and tqdm is not None: - horizon_iter = tqdm(horizon_iter, total=len( - horizons), desc="Metrics by horizon") + bins_iter = tqdm(bins_iter, total=len( + age_bins_years_arr) - 1, desc="Age bins") - for h_idx, tau in horizon_iter: - s_tau = scores[:, :, h_idx] - y_tau = build_labels_within_tau_flat( - N, K, evt_rec_idx, evt_cause, evt_dt, tau) + for b in bins_iter: + lo = float(age_bins_years_arr[b]) + hi = float(age_bins_years_arr[b + 1]) + age_label = _format_age_bin_label(lo, hi) - # Per-cause counts + Brier (vectorized) - n_pos = y_tau.sum(axis=0).astype(np.int64) - n_neg = (int(N) - n_pos).astype(np.int64) + m_rec = bin_idx == b + n_bin = int(m_rec.sum()) + if n_bin == 0: + continue - # Brier per cause: mean_i (y - s)^2 - brier_per_cause = np.mean( - (y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0) - brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan") - brier_weighted = float(np.sum( - brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan") + rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32) - # AUC: compute only for causes with enough positives and at least 1 negative - auc = np.full((K,), np.nan, dtype=np.float64) - min_pos = int(args.min_pos) - candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0)) - for k in candidates: - auc[k] = roc_auc_ovr(y_tau[:, k].astype( - np.int32), s_tau[:, k].astype(np.float64)) + # Filter events to this bin's records once. 
+ m_evt_bin = m_rec[evt_rec_idx] if evt_rec_idx.size > 0 else np.zeros( + (0,), dtype=bool) + evt_rec_idx_b = evt_rec_idx[m_evt_bin] + evt_cause_b = evt_cause[m_evt_bin] + evt_dt_b = evt_dt[m_evt_bin] - finite_auc = auc[np.isfinite(auc)] - auc_macro = float(np.mean(finite_auc) - ) if finite_auc.size > 0 else float("nan") - w_mask = np.isfinite(auc) - auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum( - n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan") - n_valid_auc = int(np.isfinite(auc).sum()) + horizon_iter = enumerate(horizons) + if show_progress and tqdm is not None: + horizon_iter = tqdm(horizon_iter, total=len( + horizons), desc=f"Horizons {age_label}") - # Append per-cause rows (vectorized via DataFrame to avoid Python loops) - per_cause_rows.append( - pd.DataFrame( - { - "tau_years": float(tau), - "cause_id": np.arange(K, dtype=np.int64), - "n_pos": n_pos, - "n_neg": n_neg, - "auc": auc, - "brier": brier_per_cause, - } + # Precompute a local index mapping for diagnostics label building. + local_map = np.full((N,), -1, dtype=np.int32) + local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32) + + for h_idx, tau in horizon_iter: + s_tau_all = scores[:, :, h_idx] + s_tau = s_tau_all[m_rec] + + # ------------------------- + # Primary metric: Top-K Event Capture@tau (event-count based) + # ------------------------- + denom_events = int(np.sum(evt_dt_b <= float(tau)) + ) if evt_dt_b.size > 0 else 0 + if denom_events == 0: + for topk in args.topk_list: + capture_rows.append( + { + "age_bin": age_label, + "tau_years": float(tau), + "topk": int(topk), + "capture_at_k": float("nan"), + "denom_events": int(0), + "numer_events": int(0), + "n_records": int(n_bin), + "n_causes": int(K), + } + ) + else: + m_evt_tau = evt_dt_b <= float(tau) + evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau] + evt_cause_tau = evt_cause_b[m_evt_tau] + + # For each K, compute whether each event's cause is in that record's Top-K list. 
+ for topk in args.topk_list: + topk = int(topk) + idx = topk_indices(s_tau_all, topk) # shape (N, topk) + idx_for_events = idx[evt_rec_idx_tau] + hits = (idx_for_events == + evt_cause_tau[:, None]).any(axis=1) + numer_events = int(hits.sum()) + capture = float(numer_events / denom_events) + capture_rows.append( + { + "age_bin": age_label, + "tau_years": float(tau), + "topk": int(topk), + "capture_at_k": capture, + "denom_events": int(denom_events), + "numer_events": int(numer_events), + "n_records": int(n_bin), + "n_causes": int(K), + } + ) + + # ------------------------- + # Primary metric: Workload–Yield (Top-p% people) + # ------------------------- + # Person-level score: max_k CIF_k(tau). This is used only for workload–yield ranking. + person_score = s_tau.max(axis=1) if K > 0 else np.zeros( + (n_bin,), dtype=np.float64) + order = np.argsort(-person_score, kind="mergesort") + + counts_per_record = _event_counts_within_tau( + n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau) + total_events = int(counts_per_record.sum()) + overall_events_per_person = ( + total_events / float(n_bin)) if n_bin > 0 else float("nan") + + for frac in args.workload_fracs: + frac = float(frac) + if frac <= 0.0: + continue + n_sel = int(math.ceil(frac * n_bin)) + n_sel = min(max(n_sel, 1), n_bin) + sel_local = order[:n_sel] + events_captured = int(counts_per_record[sel_local].sum()) + capture_rate = float( + events_captured / total_events) if total_events > 0 else float("nan") + + selected_events_per_person = ( + events_captured / float(n_sel)) if n_sel > 0 else float("nan") + lift = (selected_events_per_person / + overall_events_per_person) if overall_events_per_person > 0 else float("nan") + + workload_rows.append( + { + "age_bin": age_label, + "tau_years": float(tau), + "frac_selected": float(frac), + "n_selected": int(n_sel), + "n_records": int(n_bin), + "total_events": int(total_events), + "events_captured": int(events_captured), + "capture_rate": capture_rate, + 
"lift_events_per_person": float(lift), + "person_score_def": "max_k_CIF_k(tau)", + } + ) + + # ------------------------- + # Diagnostics (optional): approximate event-driven AUC/Brier + # ------------------------- + # Convert event-level data to binary labels y[i,k]=1 iff >=1 event of cause k within tau. + y_tau_bin = np.zeros((n_bin, K), dtype=np.int8) + if evt_dt_b.size > 0: + m_evt_tau = evt_dt_b <= float(tau) + if np.any(m_evt_tau): + rec_local = local_map[evt_rec_idx_b[m_evt_tau]] + valid = rec_local >= 0 + y_tau_bin[rec_local[valid], + evt_cause_b[m_evt_tau][valid]] = 1 + + n_pos = y_tau_bin.sum(axis=0).astype(np.int64) + n_neg = (int(n_bin) - n_pos).astype(np.int64) + + brier_per_cause = np.mean( + (y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0 ) - ) + brier_macro = float(np.mean(brier_per_cause) + ) if K > 0 else float("nan") + brier_weighted = float(np.sum( + brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan") - # Business metrics for each topK - denom_true_pairs = int(y_tau.sum()) - for topk in args.topk_list: - topk = int(topk) - idx = topk_indices(s_tau, topk) - captured = np.take_along_axis(y_tau, idx, axis=1) - hits = captured.sum(axis=1).astype(np.float64) - true_cnt = y_tau.sum(axis=1).astype(np.float64) + auc = np.full((K,), np.nan, dtype=np.float64) + min_pos = int(args.min_pos) + candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0)) + for k in candidates: + auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype( + np.int32), s_tau[:, k].astype(np.float64)) - precision_like = hits / float(min(topk, K)) - mean_precision = float(np.mean(precision_like) - ) if N > 0 else float("nan") + finite_auc = auc[np.isfinite(auc)] + auc_macro = float(np.mean(finite_auc) + ) if finite_auc.size > 0 else float("nan") + w_mask = np.isfinite(auc) + auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum( + n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan") + n_valid_auc = 
int(np.isfinite(auc).sum()) - mask_has_true = true_cnt > 0 - recall_like = np.full((N,), np.nan, dtype=np.float64) - recall_like[mask_has_true] = hits[mask_has_true] / \ - true_cnt[mask_has_true] - mean_recall = float(np.nanmean(recall_like)) if np.any( - mask_has_true) else float("nan") - median_recall = float(np.nanmedian(recall_like)) if np.any( - mask_has_true) else float("nan") - - numer_captured_pairs = int(captured.sum()) - pop_capture_rate = float( - numer_captured_pairs / denom_true_pairs) if denom_true_pairs > 0 else float("nan") - - workload_rows.append( + diag_rows.append( { + "age_bin": age_label, "tau_years": float(tau), - "topk": int(topk), - "population_capture_rate": pop_capture_rate, - "mean_precision_like": mean_precision, - "mean_recall_like": mean_recall, - "median_recall_like": median_recall, - "denom_true_pairs": denom_true_pairs, - "numer_captured_pairs": numer_captured_pairs, + "n_records": int(n_bin), + "n_causes": int(K), + "auc_macro": auc_macro, + "auc_weighted_by_npos": auc_weighted, + "n_causes_valid_auc": int(n_valid_auc), + "brier_macro": brier_macro, + "brier_weighted_by_npos": brier_weighted, } ) - per_tau_rows.append( - { - "tau_years": float(tau), - "n_records": int(N), - "n_causes": int(K), - "auc_macro": auc_macro, - "auc_weighted_by_npos": auc_weighted, - "n_causes_valid_auc": int(n_valid_auc), - "brier_macro": brier_macro, - "brier_weighted_by_npos": brier_weighted, - "total_true_pairs": denom_true_pairs, - } - ) + diag_per_cause_parts.append( + pd.DataFrame( + { + "age_bin": age_label, + "tau_years": float(tau), + "cause_id": np.arange(K, dtype=np.int64), + "n_pos": n_pos, + "n_neg": n_neg, + "auc": auc, + "brier": brier_per_cause, + } + ) + ) - out_metrics = os.path.join(run_dir, "horizon_metrics.csv") - out_pc = os.path.join(run_dir, "horizon_per_cause.csv") + out_capture = os.path.join(run_dir, "horizon_capture.csv") out_wy = os.path.join(run_dir, "workload_yield.csv") + out_diag = os.path.join(run_dir, 
"horizon_metrics.csv") + out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv") - pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False) - if per_cause_rows: - pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False) - else: - pd.DataFrame(columns=["tau_years", "cause_id", "n_pos", - "n_neg", "auc", "brier"]).to_csv(out_pc, index=False) + pd.DataFrame(capture_rows).to_csv(out_capture, index=False) pd.DataFrame(workload_rows).to_csv(out_wy, index=False) + pd.DataFrame(diag_rows).to_csv(out_diag, index=False) + if diag_per_cause_parts: + pd.concat(diag_per_cause_parts, ignore_index=True).to_csv( + out_diag_pc, index=False) + else: + pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos", + "n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False) - print(f"Wrote {out_metrics}") - print(f"Wrote {out_pc}") + print(f"Wrote {out_capture}") print(f"Wrote {out_wy}") + print(f"Wrote {out_diag} (diagnostic-only)") + print(f"Wrote {out_diag_pc} (diagnostic-only)") if __name__ == "__main__": diff --git a/evaluate_next_event.py b/evaluate_next_event.py index 65741af..76bef5a 100644 --- a/evaluate_next_event.py +++ b/evaluate_next_event.py @@ -41,8 +41,8 @@ def parse_args() -> argparse.Namespace: "--age_bins", type=str, nargs="+", - default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"], - help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)", + default=["40", "45", "50", "55", "60", "65", "70", "inf"], + help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)", ) p.add_argument( @@ -80,9 +80,14 @@ def _compute_next_event_metrics( tau_short: float, min_pos: int, ) -> tuple[list[dict], pd.DataFrame]: - """Compute next-event metrics on a given subset. + """Compute next-event *primary* metrics on a given subset. - Definitions are unchanged from the original script. + Implements 评估方案.md (Next-event): + - score_k = CIF_k(tau_short) + - Hit@K / MRR are computed on records with an observed next-event. 
+ + Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table + based on whether the cause occurs within (t0, t0+tau_short] (display-only). """ n_records_total = int(y_next.size) eligible = y_next >= 0 @@ -99,49 +104,40 @@ def _compute_next_event_metrics( {"metric": "tau_short_years", "value": float(tau_short)}) K = int(scores.shape[1]) + # Diagnostic: build per-cause AUC using within-window labels. + # This is NOT a primary metric (no IPCW / censoring adjustment). + diag_df = pd.DataFrame( + { + "cause_id": np.arange(K, dtype=np.int64), + "n_pos": np.zeros((K,), dtype=np.int64), + "n_neg": np.zeros((K,), dtype=np.int64), + "auc": np.full((K,), np.nan, dtype=np.float64), + "included": np.zeros((K,), dtype=bool), + } + ) if n_records_total == 0: - per_cause_df = pd.DataFrame( - { - "cause_id": np.arange(K, dtype=np.int64), - "n_pos": np.zeros((K,), dtype=np.int64), - "n_neg": np.zeros((K,), dtype=np.int64), - "auc": np.full((K,), np.nan, dtype=np.float64), - "included": np.zeros((K,), dtype=bool), - } - ) - metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")}) + metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")}) metrics_rows.append({"metric": "mrr", "value": float("nan")}) for k in [1, 3, 5, 10, 20]: metrics_rows.append( {"metric": f"hitrate_at_{k}", "value": float("nan")}) - metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")}) - return metrics_rows, per_cause_df + return metrics_rows, diag_df # If no eligible, keep coverage but leave accuracy-like metrics as NaN. 
if n_eligible == 0: - per_cause_df = pd.DataFrame( - { - "cause_id": np.arange(K, dtype=np.int64), - "n_pos": np.zeros((K,), dtype=np.int64), - "n_neg": np.full((K,), n_records_total, dtype=np.int64), - "auc": np.full((K,), np.nan, dtype=np.float64), - "included": np.zeros((K,), dtype=bool), - } - ) - metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")}) + metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")}) metrics_rows.append({"metric": "mrr", "value": float("nan")}) for k in [1, 3, 5, 10, 20]: metrics_rows.append( {"metric": f"hitrate_at_{k}", "value": float("nan")}) - metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")}) - return metrics_rows, per_cause_df + return metrics_rows, diag_df scores_e = scores[eligible] y_e = y_next[eligible] pred = scores_e.argmax(axis=1) acc = float((pred == y_e).mean()) - metrics_rows.append({"metric": "top1_accuracy", "value": acc}) + metrics_rows.append({"metric": "hitrate_at_1", "value": acc}) # MRR order = np.argsort(-scores_e, axis=1, kind="mergesort") @@ -158,17 +154,53 @@ def _compute_next_event_metrics( metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float(hit.mean())}) - # Macro OvR AUC per cause (optional) - n_pos = np.bincount(y_e, minlength=K).astype(np.int64) - n_neg = (int(y_e.size) - n_pos).astype(np.int64) + # Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here. + _ = min_pos + return metrics_rows, diag_df + + +def _compute_within_window_auc( + *, + scores: np.ndarray, + records: list, + tau_short: float, + min_pos: int, +) -> pd.DataFrame: + """Diagnostic-only per-cause AUC. + + Label definition (event-driven, approximate; no IPCW): + y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short]. 
+ """ + n_records = int(len(records)) + if n_records == 0: + return pd.DataFrame( + columns=["cause_id", "n_pos", "n_neg", "auc", "included"], + ) + + K = int(scores.shape[1]) + y = np.zeros((n_records, K), dtype=np.int8) + tau = float(tau_short) + + # Build labels from future events. + for i, r in enumerate(records): + if r.future_causes.size == 0: + continue + m = r.future_dt_years <= tau + if not np.any(m): + continue + y[i, r.future_causes[m]] = 1 + + n_pos = y.sum(axis=0).astype(np.int64) + n_neg = (int(n_records) - n_pos).astype(np.int64) auc = np.full((K,), np.nan, dtype=np.float64) candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0)) for k in candidates: - auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k]) + auc[k] = roc_auc_ovr(y[:, k].astype(np.int32), + scores[:, k].astype(np.float64)) included = (n_pos >= int(min_pos)) & (n_neg > 0) - per_cause_df = pd.DataFrame( + return pd.DataFrame( { "cause_id": np.arange(K, dtype=np.int64), "n_pos": n_pos, @@ -178,15 +210,6 @@ def _compute_next_event_metrics( } ) - aucs = auc[np.isfinite(auc)] - if aucs.size > 0: - metrics_rows.append( - {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))}) - else: - metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")}) - - return metrics_rows, per_cause_df - def main() -> None: args = parse_args() @@ -246,19 +269,7 @@ def main() -> None: ) # Overall (preserve existing output files/shape) - metrics_rows, per_cause_df = _compute_next_event_metrics( - scores=scores, - y_next=y_next, - tau_short=tau, - min_pos=int(args.min_pos), - ) - - out_metrics = os.path.join(run_dir, "next_event_metrics.csv") - pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False) - out_pc = os.path.join(run_dir, "next_event_per_cause.csv") - per_cause_df.to_csv(out_pc, index=False) - - # By age bin (new outputs) + # Strict protocol: evaluate independently per age bin (no mixing). 
age_bins_years = np.asarray(age_bins_years, dtype=np.float64) age_bins_days = age_bins_years * DAYS_PER_YEAR # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1}) @@ -266,7 +277,7 @@ def main() -> None: bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1 per_bin_metric_rows: List[dict] = [] - per_bin_cause_parts: List[pd.DataFrame] = [] + per_bin_auc_parts: List[pd.DataFrame] = [] for b in range(len(age_bins_years) - 1): lo = float(age_bins_years[b]) hi = float(age_bins_years[b + 1]) @@ -274,6 +285,7 @@ def main() -> None: m = bin_idx == b m_scores = scores[m] m_y = y_next[m] + m_records = [r for r, keep in zip(records, m) if bool(keep)] m_rows, m_pc = _compute_next_event_metrics( scores=m_scores, y_next=m_y, @@ -282,25 +294,32 @@ def main() -> None: ) for row in m_rows: per_bin_metric_rows.append({"age_bin": label, **row}) - m_pc = m_pc.copy() - m_pc.insert(0, "age_bin", label) - per_bin_cause_parts.append(m_pc) + m_auc = _compute_within_window_auc( + scores=m_scores, + records=m_records, + tau_short=tau, + min_pos=int(args.min_pos), + ) + m_auc.insert(0, "age_bin", label) + m_auc.insert(1, "tau_short_years", float(tau)) + per_bin_auc_parts.append(m_auc) out_metrics_bins = os.path.join( run_dir, "next_event_metrics_by_age_bin.csv") pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False) - out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv") - if per_bin_cause_parts: - pd.concat(per_bin_cause_parts, ignore_index=True).to_csv( - out_pc_bins, index=False) - else: - pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg", - "auc", "included"]).to_csv(out_pc_bins, index=False) - print(f"Wrote {out_metrics}") - print(f"Wrote {out_pc}") + out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv") + if per_bin_auc_parts: + pd.concat(per_bin_auc_parts, ignore_index=True).to_csv( + out_auc_bins, index=False) + else: + pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos", 
+ "n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False) + + print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.") + print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).") print(f"Wrote {out_metrics_bins}") - print(f"Wrote {out_pc_bins}") + print(f"Wrote {out_auc_bins}") if __name__ == "__main__":