From 76d3fed76faaf85717f157ccd2f3c71f40ea97f8 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Mon, 19 Jan 2026 00:01:21 +0800 Subject: [PATCH] Enhance compute_disease_capture_at_k_fast to support return counts and update LandmarkEvaluator for backward compatibility with new metrics structure --- evaluate.py | 131 +++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 120 insertions(+), 11 deletions(-) diff --git a/evaluate.py b/evaluate.py index f49a7b3..2dda82e 100644 --- a/evaluate.py +++ b/evaluate.py @@ -94,7 +94,8 @@ def compute_disease_capture_at_k_fast( y_scores: np.ndarray, valid_mask: np.ndarray, top_k_list: List[int], -) -> Dict[int, Dict[int, float]]: + return_counts: bool = False, +) -> Dict[int, Dict[int, float]] | Tuple[Dict[int, Dict[int, float]], np.ndarray, Dict[int, np.ndarray]]: """Vectorized Disease-Capture@K. Definition: for each disease d, among valid positives (y_true==1 and valid_mask), @@ -116,7 +117,10 @@ def compute_disease_capture_at_k_fast( top_k_list = sorted({int(k) for k in top_k_list if int(k) > 0}) capture_rates: Dict[int, Dict[int, float]] = { int(k): {} for k in top_k_list} + hits_by_k: Dict[int, np.ndarray] = {} if N == 0 or K == 0 or len(top_k_list) == 0: + if return_counts: + return capture_rates, np.zeros((K,), dtype=np.int64), hits_by_k return capture_rates topk_max = min(max(top_k_list), K) @@ -145,6 +149,8 @@ def compute_disease_capture_at_k_fast( hits_mask = pos_valid[rows, idx_k] # (N, k_eff) bool hit_diseases = idx_k[hits_mask] hits = np.bincount(hit_diseases, minlength=K).astype(np.int64) + if return_counts: + hits_by_k[int(k_val)] = hits # Convert to dict with NaNs for diseases with no valid positives. out_k: Dict[int, float] = {} @@ -154,6 +160,8 @@ def compute_disease_capture_at_k_fast( out_k[d] = float(frac[d]) if denom[d] > 0 else float('nan') capture_rates[int(k_val)] = out_k + if return_counts: + return capture_rates, denom, hits_by_k return capture_rates @@ -197,6 +205,7 @@ def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]: summary_rows: List[Dict] = [] auc_rows: List[Dict] = [] capture_rows: List[Dict] = [] + capture_mean_rows: List[Dict] = [] lift_rows: List[Dict] = [] dca_rows: List[Dict] = [] @@ -233,18 +242,72 @@ def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]: if track == 'complete_case': capture = track_res.get('disease_capture_at_k') or {} - for k_val, per_disease in capture.items(): - k_int = int(k_val) + + # Backward-compatible parsing: + # - new format: {per_disease: {k: {d: rate}}, n_positive: {d: n}, macro_avg: {k: x}, micro_avg: {k: y}} + # - old format: {k: {d: rate}} + if isinstance(capture, dict) and 'per_disease' in capture: + per_disease_by_k = capture.get('per_disease') or {} + n_positive_by_disease = capture.get('n_positive') or {} + macro_by_k = capture.get('macro_avg') or {} + micro_by_k = capture.get('micro_avg') or {} + else: + per_disease_by_k = capture + n_positive_by_disease = {} + macro_by_k = {} + micro_by_k = {} + + # Per-disease rows (+ n_positive) + for k_val, per_disease in (per_disease_by_k or {}).items(): + try: + k_int = int(k_val) + except Exception: + continue for disease_idx, rate in (per_disease or {}).items(): + try: + d_int = int(disease_idx) + except Exception: + continue + n_pos = n_positive_by_disease.get( + d_int, n_positive_by_disease.get(str(d_int), np.nan)) capture_rows.append({ 'age_cutoff': age, 'horizon': horizon, 'track': track, 'k': k_int, - 'disease_idx': int(disease_idx), + 'disease_idx': d_int, 'capture_rate': _to_float(rate), + 'n_positive': _to_float(n_pos), }) + # Macro/Micro summary rows + # Prefer explicit macro/micro from the new format; otherwise compute macro from rates. + for k_val, per_disease in (per_disease_by_k or {}).items(): + try: + k_int = int(k_val) + except Exception: + continue + + macro = macro_by_k.get( + k_int, macro_by_k.get(str(k_int), np.nan)) + micro = micro_by_k.get( + k_int, micro_by_k.get(str(k_int), np.nan)) + if macro is None or (isinstance(macro, float) and np.isnan(macro)): + rates = [ + _to_float(r) for r in (per_disease or {}).values() + ] + macro = float(np.nanmean(rates)) if len( + rates) else np.nan + + capture_mean_rows.append({ + 'age_cutoff': age, + 'horizon': horizon, + 'track': track, + 'k': k_int, + 'macro_avg': _to_float(macro), + 'micro_avg': _to_float(micro), + }) + lift_yield = track_res.get('lift_and_yield') or {} overall = (lift_yield.get('overall') or {}) if isinstance( lift_yield, dict) else {} @@ -307,6 +370,11 @@ def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]: df_capture.to_csv(capture_path, index=False) paths['capture_at_k'] = capture_path + df_capture_mean = pd.DataFrame(capture_mean_rows) + capture_mean_path = os.path.join(out_dir, 'capture_at_k_mean.csv') + df_capture_mean.to_csv(capture_mean_path, index=False) + paths['capture_at_k_mean'] = capture_mean_path + df_lift = pd.DataFrame(lift_rows) lift_path = os.path.join(out_dir, 'lift_yield.csv') df_lift.to_csv(lift_path, index=False) @@ -921,7 +989,7 @@ class LandmarkEvaluator: risk_scores: np.ndarray, labels: np.ndarray, valid_mask: np.ndarray, - ) -> Dict[int, Dict[int, float]]: + ) -> Dict[str, Dict]: """ Compute Disease-Capture@K: fraction of true positives where the true disease appears in the patient's top-K predicted risks. @@ -932,15 +1000,21 @@ class LandmarkEvaluator: valid_mask: (N, K) boolean mask Returns: - capture_rates: Dict[K_value][disease_idx] -> capture rate + metrics: Dict with keys: + - per_disease: Dict[k][disease_idx] -> capture rate + - n_positive: Dict[disease_idx] -> number of valid positives (support) + - n_captured: Dict[k][disease_idx] -> number of captured positives + - macro_avg: Dict[k] -> macro-average capture rate + - micro_avg: Dict[k] -> micro-average capture rate """ # Fast path (vectorized): compute top-k_max once per sample. t0 = time.perf_counter() if self.profile_metrics else None - capture_fast = compute_disease_capture_at_k_fast( + capture_rates, denom, hits_by_k = compute_disease_capture_at_k_fast( y_true=labels, y_scores=risk_scores, valid_mask=valid_mask, top_k_list=self.top_k_values, + return_counts=True, ) if self.profile_metrics and t0 is not None: dt = time.perf_counter() - t0 @@ -981,7 +1055,36 @@ class LandmarkEvaluator: f" [check] capture@K ok on subset (N={n_sub}). slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s" ) - return capture_fast + K = int(risk_scores.shape[1]) + n_positive: Dict[int, int] = {int(d): int(denom[d]) for d in range(K)} + n_captured: Dict[int, Dict[int, int]] = {} + macro_avg: Dict[int, float] = {} + micro_avg: Dict[int, float] = {} + + total_pos = int(denom.sum()) + for k_val in self.top_k_values: + k_int = int(k_val) + hits = hits_by_k.get(k_int, np.zeros((K,), dtype=np.int64)) + n_captured[k_int] = {int(d): int(hits[d]) for d in range(K)} + + # Macro: mean across diseases with support (ignore NaNs) + rates = capture_rates.get(k_int, {}) + rate_values = np.array([rates.get(d, np.nan) + for d in range(K)], dtype=np.float64) + macro_avg[k_int] = float(np.nanmean( + rate_values)) if rate_values.size else float('nan') + + # Micro: sum captured / sum positives + micro_avg[k_int] = float( + hits.sum() / total_pos) if total_pos > 0 else float('nan') + + return { + 'per_disease': capture_rates, + 'n_positive': n_positive, + 'n_captured': n_captured, + 'macro_avg': macro_avg, + 'micro_avg': micro_avg, + } def _compute_disease_capture_at_k_slow( self, @@ -1627,9 +1730,15 @@ def print_summary(results: Dict): if 'disease_capture_at_k' in cc: print(f" Disease Capture:") for k in [5, 10, 20, 50]: - if k in cc['disease_capture_at_k']: - rates = list(cc['disease_capture_at_k'][k].values()) - mean_rate = np.nanmean(rates) + capture = cc.get('disease_capture_at_k') or {} + if isinstance(capture, dict) and 'macro_avg' in capture: + macro = (capture.get('macro_avg') or {}).get(k, np.nan) + micro = (capture.get('micro_avg') or {}).get(k, np.nan) + print( + f" Top-{k}: macro={_to_float(macro):.3f}, micro={_to_float(micro):.3f}") + elif isinstance(capture, dict) and k in capture: + rates = list((capture.get(k) or {}).values()) + mean_rate = np.nanmean([_to_float(r) for r in rates]) print(f" Top-{k}: {mean_rate:.3f}") # Show lift and yield