From 6de2132e84042498e8d443896f451b9f01abf167 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Sun, 18 Jan 2026 18:14:45 +0800
Subject: [PATCH] Add fast vectorized implementation for Disease-Capture@K and enhance LandmarkEvaluator with profiling and correctness check options

---
Review notes (placed after the three-dash line, so they are not part of the
commit message):

- NOTE(review): the line structure of this patch was restored from a copy
  whose newlines and indentation had been collapsed. The number of added and
  removed lines matches every hunk header and the diffstat, but the
  indentation of all lines, the position of blank context lines (one blank
  context line in the -855,39 hunk could not be located and was placed after
  `pos_mask = y_true == 1`) and any runs of spaces inside string literals or
  comments are inferred. Regenerate with `git format-patch` before applying.
- NOTE(review): compute_disease_capture_at_k_fast drops non-positive entries
  of top_k_list, while the --check_capture_at_k loop and
  _compute_disease_capture_at_k_slow iterate self.top_k_values unfiltered; a
  value <= 0 in top_k_values would make fast[int(k_val)] raise KeyError.
- NOTE(review): the fast path picks the top-K via argpartition plus an argsort
  of the partition, the slow path via a full argsort. With tied scores at the
  K-th position the two can presumably select different diseases, so the
  exact-equality check may fail on ties - confirm whether ties occur.

 evaluate.py | 204 ++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 176 insertions(+), 28 deletions(-)

diff --git a/evaluate.py b/evaluate.py
index fb2ca66..6109b32 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -11,6 +11,7 @@ Implements the comprehensive evaluation framework defined in evaluate_design.md:
 import argparse
 import json
 import os
+import time
 from pathlib import Path
 from typing import Dict, List, Tuple, Optional
 import warnings
@@ -87,6 +88,74 @@ def _to_float(x):
         return np.nan
 
 
+def compute_disease_capture_at_k_fast(
+    y_true: np.ndarray,
+    y_scores: np.ndarray,
+    valid_mask: np.ndarray,
+    top_k_list: List[int],
+) -> Dict[int, Dict[int, float]]:
+    """Vectorized Disease-Capture@K.
+
+    Definition: for each disease d, among valid positives (y_true==1 and valid_mask),
+    capture@K is the fraction whose predicted top-K diseases contain d.
+
+    This implementation avoids per-positive full argsorts by computing top-k_max once
+    per sample (using argpartition), sorting those k_max indices by score, and then
+    aggregating hits via bincount.
+    """
+
+    if y_scores.ndim != 2:
+        raise ValueError(
+            f"Expected y_scores 2D (N,K), got shape={y_scores.shape}")
+    if y_true.shape != y_scores.shape or valid_mask.shape != y_scores.shape:
+        raise ValueError(
+            f"Shape mismatch: y_true={y_true.shape}, y_scores={y_scores.shape}, valid_mask={valid_mask.shape}")
+
+    N, K = y_scores.shape
+    top_k_list = sorted({int(k) for k in top_k_list if int(k) > 0})
+    capture_rates: Dict[int, Dict[int, float]] = {
+        int(k): {} for k in top_k_list}
+    if N == 0 or K == 0 or len(top_k_list) == 0:
+        return capture_rates
+
+    topk_max = min(max(top_k_list), K)
+
+    # Valid positives per disease are the denominator.
+    pos_valid = (y_true == 1) & valid_mask.astype(bool)
+    denom = pos_valid.sum(axis=0).astype(np.int64)  # (K,)
+
+    # Compute top-k_max indices per sample once (unordered), then sort those indices by score.
+    part = np.argpartition(y_scores, -topk_max,
+                           axis=1)[:, -topk_max:]  # (N, topk_max)
+    part_scores = np.take_along_axis(y_scores, part, axis=1)
+    order = np.argsort(part_scores, axis=1)[:, ::-1]
+    topk_idx = np.take_along_axis(
+        part, order, axis=1).astype(np.int32)  # (N, topk_max)
+
+    rows = np.arange(N)[:, None]
+
+    for k_val in top_k_list:
+        k_eff = min(int(k_val), topk_max)
+        idx_k = topk_idx[:, :k_eff]  # (N, k_eff)
+
+        # For each sample, we count a "hit" for disease d when:
+        # d is in top-K (true by construction for idx_k)
+        # AND sample is a valid positive for disease d.
+        hits_mask = pos_valid[rows, idx_k]  # (N, k_eff) bool
+        hit_diseases = idx_k[hits_mask]
+        hits = np.bincount(hit_diseases, minlength=K).astype(np.int64)
+
+        # Convert to dict with NaNs for diseases with no valid positives.
+        out_k: Dict[int, float] = {}
+        with np.errstate(divide='ignore', invalid='ignore'):
+            frac = hits / denom
+        for d in range(K):
+            out_k[d] = float(frac[d]) if denom[d] > 0 else float('nan')
+        capture_rates[int(k_val)] = out_k
+
+    return capture_rates
+
+
 def save_summary_json(summary: Dict, output_path: str) -> None:
     """Save a single JSON summary file."""
 
@@ -266,6 +335,9 @@ class LandmarkEvaluator:
         batch_size: int = 256,
         num_workers: int = 4,
         compile_model: bool = True,
+        check_capture_at_k: bool = False,
+        profile_metrics: bool = False,
+        capture_check_n: int = 200,
     ):
         self.model = model.to(device).eval()
         self.head = head.to(device).eval()
@@ -276,6 +348,11 @@ class LandmarkEvaluator:
         self.batch_size = batch_size
         self.num_workers = num_workers
 
+        self.check_capture_at_k = bool(check_capture_at_k)
+        self.profile_metrics = bool(profile_metrics)
+        self.capture_check_n = int(capture_check_n)
+        self._did_capture_check = False
+
         use_cuda = str(self.device).startswith(
             "cuda") and torch.cuda.is_available()
         if use_cuda:
@@ -775,27 +852,28 @@ class LandmarkEvaluator:
         Returns:
             auc_scores: Dict mapping disease_idx -> AUC
         """
-        auc_scores = {}
+        auc_scores: Dict[int, float] = {}
         n_diseases = risk_scores.shape[1]
+        m = valid_mask.astype(bool)
+
+        # Fast pre-filter: skip diseases without both classes among valid entries.
+        pos = (labels == 1) & m
+        n_valid = m.sum(axis=0)
+        n_pos = pos.sum(axis=0)
+        n_neg = n_valid - n_pos
 
         for k in range(n_diseases):
-            mk = valid_mask[:, k]
-            if not np.any(mk):
+            if n_valid[k] == 0 or n_pos[k] == 0 or n_neg[k] == 0:
                 auc_scores[k] = np.nan
                 continue
 
-            y_true = labels[mk, k]
-            y_score = risk_scores[mk, k]
-
-            # Check if we have both classes
-            if len(np.unique(y_true)) < 2:
+            mk = m[:, k]
+            y_true_k = labels[mk, k]
+            y_score_k = risk_scores[mk, k]
+            try:
+                auc_scores[k] = float(roc_auc_score(y_true_k, y_score_k))
+            except Exception:
                 auc_scores[k] = np.nan
-            else:
-                try:
-                    auc = roc_auc_score(y_true, y_score)
-                    auc_scores[k] = auc
-                except Exception:
-                    auc_scores[k] = np.nan
 
         return auc_scores
 
@@ -855,39 +933,89 @@ class LandmarkEvaluator:
         Returns:
             capture_rates: Dict[K_value][disease_idx] -> capture rate
         """
-        capture_rates = {k: {} for k in self.top_k_values}
+        # Fast path (vectorized): compute top-k_max once per sample.
+        t0 = time.perf_counter() if self.profile_metrics else None
+        capture_fast = compute_disease_capture_at_k_fast(
+            y_true=labels,
+            y_scores=risk_scores,
+            valid_mask=valid_mask,
+            top_k_list=self.top_k_values,
+        )
+        if self.profile_metrics and t0 is not None:
+            dt = time.perf_counter() - t0
+            print(
+                f" [profile] capture@K fast: {dt:.3f}s (N={risk_scores.shape[0]}, K={risk_scores.shape[1]})")
+        # Optional correctness check against the slow reference on a small subset.
+        if self.check_capture_at_k and (not self._did_capture_check):
+            self._did_capture_check = True
+            rng = np.random.default_rng(0)
+            n_sub = min(self.capture_check_n, risk_scores.shape[0])
+            sub_idx = rng.choice(
+                risk_scores.shape[0], size=n_sub, replace=False)
+            rs = risk_scores[sub_idx]
+            ys = labels[sub_idx]
+            vm = valid_mask[sub_idx]
+
+            t1 = time.perf_counter()
+            slow = self._compute_disease_capture_at_k_slow(rs, ys, vm)
+            t2 = time.perf_counter()
+            fast = compute_disease_capture_at_k_fast(
+                ys, rs, vm, self.top_k_values)
+            t3 = time.perf_counter()
+
+            def _eq(a: float, b: float) -> bool:
+                if np.isnan(a) and np.isnan(b):
+                    return True
+                return float(a) == float(b)
+
+            for k_val in self.top_k_values:
+                for d in range(rs.shape[1]):
+                    if not _eq(slow[int(k_val)][d], fast[int(k_val)][d]):
+                        raise AssertionError(
+                            f"Capture@{k_val} mismatch for disease {d}: slow={slow[int(k_val)][d]} fast={fast[int(k_val)][d]}"
+                        )
+
+            print(
+                f" [check] capture@K ok on subset (N={n_sub}). slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s"
+            )
+
+        return capture_fast
+
+    def _compute_disease_capture_at_k_slow(
+        self,
+        risk_scores: np.ndarray,
+        labels: np.ndarray,
+        valid_mask: np.ndarray,
+    ) -> Dict[int, Dict[int, float]]:
+        """Reference implementation (slow): kept for correctness checking."""
+        capture_rates = {int(k): {} for k in self.top_k_values}
         n_diseases = risk_scores.shape[1]
 
         for disease_idx in range(n_diseases):
             mk = valid_mask[:, disease_idx]
             if not np.any(mk):
                 for k in self.top_k_values:
-                    capture_rates[k][disease_idx] = np.nan
+                    capture_rates[int(k)][disease_idx] = np.nan
                 continue
 
             y_true = labels[mk, disease_idx]
             y_scores = risk_scores[mk]  # (N_valid_k, K)
 
-            # Find patients with positive label for this disease
             pos_mask = y_true == 1
 
             if pos_mask.sum() == 0:
                 for k in self.top_k_values:
-                    capture_rates[k][disease_idx] = np.nan
+                    capture_rates[int(k)][disease_idx] = np.nan
                 continue
 
-            # For each positive patient, check if true disease is in top-K
+            pos_idx = np.where(pos_mask)[0]
             for k_val in self.top_k_values:
                 captures = []
-                for i in np.where(pos_mask)[0]:
-                    # Get top-K disease indices for this patient
-                    top_k_diseases = np.argsort(y_scores[i])[::-1][:k_val]
-                    # Check if true disease is in top-K
-                    is_captured = disease_idx in top_k_diseases
-                    captures.append(int(is_captured))
-
-                capture_rate = np.mean(captures) if captures else np.nan
-                capture_rates[k_val][disease_idx] = capture_rate
+                for i in pos_idx:
+                    top_k_diseases = np.argsort(y_scores[i])[::-1][:int(k_val)]
+                    captures.append(int(disease_idx in top_k_diseases))
+                capture_rates[int(k_val)][disease_idx] = float(
+                    np.mean(captures)) if captures else np.nan
 
         return capture_rates
 
@@ -1558,6 +1686,23 @@ def main():
         help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
     )
 
+    parser.add_argument(
+        '--check_capture_at_k',
+        action='store_true',
+        help='Run a one-time correctness check: slow vs fast Disease-Capture@K on a small subset'
+    )
+    parser.add_argument(
+        '--profile_metrics',
+        action='store_true',
+        help='Print basic timings for CPU-side metric computations'
+    )
+    parser.add_argument(
+        '--capture_check_n',
+        type=int,
+        default=200,
+        help='Number of samples used for the capture@K slow-vs-fast check (default: 200)'
+    )
+
     args = parser.parse_args()
 
     # Load model and dataset
@@ -1575,6 +1720,9 @@ def main():
         batch_size=args.batch_size,
         num_workers=args.num_workers,
         compile_model=(not args.no_compile),
+        check_capture_at_k=args.check_capture_at_k,
+        profile_metrics=args.profile_metrics,
+        capture_check_n=args.capture_check_n,
     )
 
     # Run evaluation