From d8b322cbee29dbb0a0cdd413a5a536967870bb24 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sun, 11 Jan 2026 00:52:35 +0800 Subject: [PATCH] Enhance Event Rate@K and Recall@K computations with random ranking baseline and additional metrics --- evaluate_models.py | 189 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 185 insertions(+), 4 deletions(-) diff --git a/evaluate_models.py b/evaluate_models.py index a1165df..cf86d2e 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -943,7 +943,11 @@ def compute_event_rate_at_topk_causes( topk_list: list of K values to evaluate Returns: - List of rows with: topk, mean, median, n_total. + List of rows with: + - topk + - event_rate_mean / event_rate_median + - recall_mean / recall_median (averaged over individuals with >=1 true cause) + - n_total / n_valid_recall """ p = np.asarray(p_tau, dtype=float) y = np.asarray(y_tau, dtype=float) @@ -960,7 +964,10 @@ def compute_event_rate_at_topk_causes( "topk": int(max(1, int(kk))), "event_rate_mean": float("nan"), "event_rate_median": float("nan"), + "recall_mean": float("nan"), + "recall_median": float("nan"), "n_total": int(n), + "n_valid_recall": 0, } ) return out @@ -985,19 +992,140 @@ def compute_event_rate_at_topk_causes( kk_eff = min(int(kk), int(k_total)) idx = top_sorted[:, :kk_eff] y_sel = np.take_along_axis(y, idx, axis=1) - # fraction of selected causes that occur - per_person = np.mean(y_sel, axis=1) + # Selected true causes per person + hit = np.sum(y_sel, axis=1) + # Precision-like: fraction of selected causes that occur + per_person = hit / \ + float(kk_eff) if kk_eff > 0 else np.full((n,), np.nan) + + # Recall@K: fraction of true causes covered by top-K (undefined when no true cause) + g = np.sum(y, axis=1) + valid = g > 0 + recall = np.full((n,), np.nan, dtype=float) + recall[valid] = hit[valid] / g[valid] out_rows.append( { "topk": int(kk_eff), "event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"), "event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"), + "recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"), + "recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"), "n_total": int(n), + "n_valid_recall": int(np.sum(valid)), } ) return out_rows +def compute_random_ranking_baseline_topk( + y_tau: np.ndarray, + topk_list: Sequence[int], + *, + z: float = 1.645, +) -> List[Dict[str, Any]]: + """Random ranking baseline for Event Rate@K and Recall@K. + + Baseline definition: + - For each individual, pick K causes uniformly at random without replacement. + - EventRate@K = (# selected causes that occur) / K. + - Recall@K = (# selected causes that occur) / (# causes that occur), averaged over individuals with >=1 true cause. + + This function computes the expected baseline mean and an approximate 5-95% range + for the population mean using a normal approximation of the hypergeometric variance. + + Args: + y_tau: (N, K_total) binary labels + topk_list: K values + z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%) + + Returns: + Rows with baseline means and p05/p95 for both metrics. + """ + y = np.asarray(y_tau, dtype=float) + if y.ndim != 2: + raise ValueError( + "compute_random_ranking_baseline_topk expects y_tau with shape (N,K)") + + n, k_total = y.shape + topks = sorted({int(x) for x in topk_list if int(x) > 0}) + if not topks: + return [] + + g = np.sum(y, axis=1) # (N,) + valid = g > 0 + n_valid = int(np.sum(valid)) + + out: List[Dict[str, Any]] = [] + for kk in topks: + kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk) + if n == 0 or k_total == 0 or kk_eff <= 0: + out.append( + { + "topk": int(max(1, kk_eff)), + "baseline_event_rate_mean": float("nan"), + "baseline_event_rate_p05": float("nan"), + "baseline_event_rate_p95": float("nan"), + "baseline_recall_mean": float("nan"), + "baseline_recall_p05": float("nan"), + "baseline_recall_p95": float("nan"), + "n_total": int(n), + "n_valid_recall": int(n_valid), + "k_total": int(k_total), + "baseline_method": "random_ranking_hypergeometric_normal_approx", + } + ) + continue + + # Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total. + er_mean = float(np.mean(g / float(k_total))) + + # Variance of hypergeometric count X: + # Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p=g/K_total. + if k_total > 1 and kk_eff < k_total: + p = g / float(k_total) + finite_corr = (float(k_total - kk_eff) / float(k_total - 1)) + var_x = float(kk_eff) * p * (1.0 - p) * finite_corr + else: + var_x = np.zeros_like(g, dtype=float) + + var_er = var_x / (float(kk_eff) ** 2) + se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n)) + er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0)) + er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0)) + + # Expected Recall@K for individuals with g>0 is K/K_total (clipped). + rec_mean = float(min(float(kk_eff) / float(k_total), 1.0)) + if n_valid > 0: + var_rec = np.zeros_like(g, dtype=float) + gv = g[valid] + var_xv = var_x[valid] + # Var( X / g ) = Var(X) / g^2 (approx; g is fixed per individual) + var_rec_v = var_xv / (gv ** 2) + se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid) + rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0)) + rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0)) + else: + rec_p05 = float("nan") + rec_p95 = float("nan") + + out.append( + { + "topk": int(kk_eff), + "baseline_event_rate_mean": er_mean, + "baseline_event_rate_p05": er_p05, + "baseline_event_rate_p95": er_p95, + "baseline_recall_mean": rec_mean, + "baseline_recall_p05": float(rec_p05), + "baseline_recall_p95": float(rec_p95), + "n_total": int(n), + "n_valid_recall": int(n_valid), + "k_total": int(k_total), + "baseline_method": "random_ranking_hypergeometric_normal_approx", + } + ) + return out + + def count_occurs_within_horizon( loader: DataLoader, offset_years: float, @@ -1560,6 +1688,12 @@ def main() -> int: default=[5, 10, 20, 50], help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)", ) + ap.add_argument( + "--cause_concentration_write_random_baseline", + action="store_true", + default=False, + help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)", + ) args = ap.parse_args() set_deterministic(args.seed) @@ -1645,6 +1779,7 @@ def main() -> int: cap_points_rows: List[Dict[str, Any]] = [] cap_curve_rows: List[Dict[str, Any]] = [] conc_rows: List[Dict[str, Any]] = [] + conc_base_rows: List[Dict[str, Any]] = [] # Track per-model integrity status for meta JSON. integrity_meta: Dict[str, Any] = {} @@ -1774,6 +1909,21 @@ def main() -> int: } ) + if bool(args.cause_concentration_write_random_baseline): + for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes): + conc_base_rows.append( + { + "model_id": model_id, + "model_type": model_type, + "loss_type": loss_type_id, + "age_encoder": age_encoder, + "cov_type": cov_type, + "horizon": float(tau), + "sex": sex_label, + **rr, + } + ) + # Convenience slices for user-facing experiments (focus causes only). cause_cif_focus = cif_full[:, top_cause_ids, :] y_within_focus = y_cause_within_tau[:, top_cause_ids, :] @@ -1965,11 +2115,41 @@ def main() -> int: "topk", "event_rate_mean", "event_rate_median", + "recall_mean", + "recall_median", "n_total", + "n_valid_recall", ], conc_rows, ) + if conc_base_rows: + write_simple_csv( + os.path.join( + export_dir, "high_risk_cause_concentration_random_baseline.csv"), + [ + "model_id", + "model_type", + "loss_type", + "age_encoder", + "cov_type", + "horizon", + "sex", + "topk", + "baseline_event_rate_mean", + "baseline_event_rate_p05", + "baseline_event_rate_p95", + "baseline_recall_mean", + "baseline_recall_p05", + "baseline_recall_p95", + "n_total", + "n_valid_recall", + "k_total", + "baseline_method", + ], + conc_base_rows, + ) + # Manifest markdown (stable, user-facing) manifest_path = os.path.join(export_dir, "eval_exports_manifest.md") with open(manifest_path, "w", encoding="utf-8") as f: @@ -1981,7 +2161,8 @@ def main() -> int: "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n" "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n" "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n" - "- high_risk_cause_concentration.csv: Event Rate@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart of Event Rate@K vs K.\n" + "- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart vs K.\n" + "- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n" ) meta = {