Add fast vectorized implementation for Disease-Capture@K and enhance LandmarkEvaluator with profiling and correctness check options
evaluate.py (204 changed lines)
@@ -11,6 +11,7 @@ Implements the comprehensive evaluation framework defined in evaluate_design.md:
 import argparse
 import json
 import os
+import time
 from pathlib import Path
 from typing import Dict, List, Tuple, Optional
 import warnings
@@ -87,6 +88,74 @@ def _to_float(x):
         return np.nan


+def compute_disease_capture_at_k_fast(
+    y_true: np.ndarray,
+    y_scores: np.ndarray,
+    valid_mask: np.ndarray,
+    top_k_list: List[int],
+) -> Dict[int, Dict[int, float]]:
+    """Vectorized Disease-Capture@K.
+
+    Definition: for each disease d, among valid positives (y_true==1 and valid_mask),
+    capture@K is the fraction whose predicted top-K diseases contain d.
+
+    This implementation avoids per-positive full argsorts by computing top-k_max once
+    per sample (using argpartition), sorting those k_max indices by score, and then
+    aggregating hits via bincount.
+    """
+
+    if y_scores.ndim != 2:
+        raise ValueError(
+            f"Expected y_scores 2D (N,K), got shape={y_scores.shape}")
+    if y_true.shape != y_scores.shape or valid_mask.shape != y_scores.shape:
+        raise ValueError(
+            f"Shape mismatch: y_true={y_true.shape}, y_scores={y_scores.shape}, valid_mask={valid_mask.shape}")
+
+    N, K = y_scores.shape
+    top_k_list = sorted({int(k) for k in top_k_list if int(k) > 0})
+    capture_rates: Dict[int, Dict[int, float]] = {
+        int(k): {} for k in top_k_list}
+    if N == 0 or K == 0 or len(top_k_list) == 0:
+        return capture_rates
+
+    topk_max = min(max(top_k_list), K)
+
+    # Valid positives per disease are the denominator.
+    pos_valid = (y_true == 1) & valid_mask.astype(bool)
+    denom = pos_valid.sum(axis=0).astype(np.int64)  # (K,)
+
+    # Compute top-k_max indices per sample once (unordered), then sort those indices by score.
+    part = np.argpartition(y_scores, -topk_max,
+                           axis=1)[:, -topk_max:]  # (N, topk_max)
+    part_scores = np.take_along_axis(y_scores, part, axis=1)
+    order = np.argsort(part_scores, axis=1)[:, ::-1]
+    topk_idx = np.take_along_axis(
+        part, order, axis=1).astype(np.int32)  # (N, topk_max)
+
+    rows = np.arange(N)[:, None]
+
+    for k_val in top_k_list:
+        k_eff = min(int(k_val), topk_max)
+        idx_k = topk_idx[:, :k_eff]  # (N, k_eff)
+
+        # For each sample, we count a "hit" for disease d when:
+        #   d is in top-K (true by construction for idx_k)
+        #   AND sample is a valid positive for disease d.
+        hits_mask = pos_valid[rows, idx_k]  # (N, k_eff) bool
+        hit_diseases = idx_k[hits_mask]
+        hits = np.bincount(hit_diseases, minlength=K).astype(np.int64)

+        # Convert to dict with NaNs for diseases with no valid positives.
+        out_k: Dict[int, float] = {}
+        with np.errstate(divide='ignore', invalid='ignore'):
+            frac = hits / denom
+        for d in range(K):
+            out_k[d] = float(frac[d]) if denom[d] > 0 else float('nan')
+        capture_rates[int(k_val)] = out_k
+
+    return capture_rates
+
+
 def save_summary_json(summary: Dict, output_path: str) -> None:
     """Save a single JSON summary file."""
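Note (not part of the diff): a minimal sketch of how the new helper can be exercised on a toy matrix, assuming evaluate.py is importable from the working directory. Disease 3 below is positive only in sample 1, whose top-1 prediction is disease 1, so its capture@1 should come out as 0.0.

```python
import numpy as np
from evaluate import compute_disease_capture_at_k_fast  # assumes evaluate.py is on the import path

# 3 samples x 4 diseases; every entry is observable (valid_mask all True).
y_scores = np.array([[0.9, 0.1, 0.5, 0.2],
                     [0.2, 0.8, 0.3, 0.1],
                     [0.1, 0.2, 0.9, 0.7]])
y_true = np.array([[1, 0, 0, 0],
                   [0, 1, 0, 1],
                   [0, 0, 1, 0]])
valid_mask = np.ones_like(y_true, dtype=bool)

rates = compute_disease_capture_at_k_fast(y_true, y_scores, valid_mask, top_k_list=[1, 2])
print(rates[1])  # {0: 1.0, 1: 1.0, 2: 1.0, 3: 0.0}
print(rates[2])  # disease 3 is still outside sample 1's top-2 ({1, 2}), so it stays 0.0
```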
@@ -266,6 +335,9 @@ class LandmarkEvaluator:
         batch_size: int = 256,
         num_workers: int = 4,
         compile_model: bool = True,
+        check_capture_at_k: bool = False,
+        profile_metrics: bool = False,
+        capture_check_n: int = 200,
     ):
         self.model = model.to(device).eval()
         self.head = head.to(device).eval()
@@ -276,6 +348,11 @@ class LandmarkEvaluator:
         self.batch_size = batch_size
         self.num_workers = num_workers

+        self.check_capture_at_k = bool(check_capture_at_k)
+        self.profile_metrics = bool(profile_metrics)
+        self.capture_check_n = int(capture_check_n)
+        self._did_capture_check = False
+
         use_cuda = str(self.device).startswith(
             "cuda") and torch.cuda.is_available()
         if use_cuda:
@@ -775,27 +852,28 @@ class LandmarkEvaluator:
         Returns:
             auc_scores: Dict mapping disease_idx -> AUC
         """
-        auc_scores = {}
+        auc_scores: Dict[int, float] = {}
         n_diseases = risk_scores.shape[1]
+        m = valid_mask.astype(bool)
+
+        # Fast pre-filter: skip diseases without both classes among valid entries.
+        pos = (labels == 1) & m
+        n_valid = m.sum(axis=0)
+        n_pos = pos.sum(axis=0)
+        n_neg = n_valid - n_pos

         for k in range(n_diseases):
-            mk = valid_mask[:, k]
-            if not np.any(mk):
+            if n_valid[k] == 0 or n_pos[k] == 0 or n_neg[k] == 0:
                 auc_scores[k] = np.nan
                 continue

-            y_true = labels[mk, k]
-            y_score = risk_scores[mk, k]
-
-            # Check if we have both classes
-            if len(np.unique(y_true)) < 2:
+            mk = m[:, k]
+            y_true_k = labels[mk, k]
+            y_score_k = risk_scores[mk, k]
+            try:
+                auc_scores[k] = float(roc_auc_score(y_true_k, y_score_k))
+            except Exception:
+                auc_scores[k] = np.nan
-            else:
-                try:
-                    auc = roc_auc_score(y_true, y_score)
-                    auc_scores[k] = auc
-                except Exception:
-                    auc_scores[k] = np.nan

         return auc_scores
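Note (illustrative, not from this commit): the pre-filter exists because sklearn's roc_auc_score needs at least one positive and one negative among a column's valid entries; counting n_pos/n_neg up front skips those columns without going through the exception path. A standalone sketch of the same counting logic:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

labels = np.array([[1, 0, 1],
                   [0, 1, 1],
                   [1, 0, 1],
                   [0, 1, 0]])
scores = np.random.default_rng(0).random((4, 3))
valid = np.array([[1, 1, 1],
                  [1, 1, 1],
                  [1, 1, 1],
                  [1, 1, 0]], dtype=bool)  # disease 2 has no valid negative

pos = (labels == 1) & valid
n_valid = valid.sum(axis=0)
n_pos = pos.sum(axis=0)
n_neg = n_valid - n_pos

aucs = {}
for k in range(labels.shape[1]):
    if n_valid[k] == 0 or n_pos[k] == 0 or n_neg[k] == 0:
        aucs[k] = float("nan")  # roc_auc_score would raise ("only one class present")
        continue
    mk = valid[:, k]
    aucs[k] = float(roc_auc_score(labels[mk, k], scores[mk, k]))

print(aucs)  # diseases 0 and 1 get finite AUCs; disease 2 is NaN
```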
@@ -855,39 +933,89 @@ class LandmarkEvaluator:
         Returns:
             capture_rates: Dict[K_value][disease_idx] -> capture rate
         """
-        capture_rates = {k: {} for k in self.top_k_values}
+        # Fast path (vectorized): compute top-k_max once per sample.
+        t0 = time.perf_counter() if self.profile_metrics else None
+        capture_fast = compute_disease_capture_at_k_fast(
+            y_true=labels,
+            y_scores=risk_scores,
+            valid_mask=valid_mask,
+            top_k_list=self.top_k_values,
+        )
+        if self.profile_metrics and t0 is not None:
+            dt = time.perf_counter() - t0
+            print(
+                f" [profile] capture@K fast: {dt:.3f}s (N={risk_scores.shape[0]}, K={risk_scores.shape[1]})")
+
+        # Optional correctness check against the slow reference on a small subset.
+        if self.check_capture_at_k and (not self._did_capture_check):
+            self._did_capture_check = True
+            rng = np.random.default_rng(0)
+            n_sub = min(self.capture_check_n, risk_scores.shape[0])
+            sub_idx = rng.choice(
+                risk_scores.shape[0], size=n_sub, replace=False)
+            rs = risk_scores[sub_idx]
+            ys = labels[sub_idx]
+            vm = valid_mask[sub_idx]
+
+            t1 = time.perf_counter()
+            slow = self._compute_disease_capture_at_k_slow(rs, ys, vm)
+            t2 = time.perf_counter()
+            fast = compute_disease_capture_at_k_fast(
+                ys, rs, vm, self.top_k_values)
+            t3 = time.perf_counter()
+
+            def _eq(a: float, b: float) -> bool:
+                if np.isnan(a) and np.isnan(b):
+                    return True
+                return float(a) == float(b)
+
+            for k_val in self.top_k_values:
+                for d in range(rs.shape[1]):
+                    if not _eq(slow[int(k_val)][d], fast[int(k_val)][d]):
+                        raise AssertionError(
+                            f"Capture@{k_val} mismatch for disease {d}: slow={slow[int(k_val)][d]} fast={fast[int(k_val)][d]}"
+                        )
+
+            print(
+                f" [check] capture@K ok on subset (N={n_sub}). slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s"
+            )
+
+        return capture_fast
+
+    def _compute_disease_capture_at_k_slow(
+        self,
+        risk_scores: np.ndarray,
+        labels: np.ndarray,
+        valid_mask: np.ndarray,
+    ) -> Dict[int, Dict[int, float]]:
+        """Reference implementation (slow): kept for correctness checking."""
+        capture_rates = {int(k): {} for k in self.top_k_values}
         n_diseases = risk_scores.shape[1]

         for disease_idx in range(n_diseases):
             mk = valid_mask[:, disease_idx]
             if not np.any(mk):
                 for k in self.top_k_values:
-                    capture_rates[k][disease_idx] = np.nan
+                    capture_rates[int(k)][disease_idx] = np.nan
                 continue

             y_true = labels[mk, disease_idx]
             y_scores = risk_scores[mk]  # (N_valid_k, K)

             # Find patients with positive label for this disease
             pos_mask = y_true == 1
             if pos_mask.sum() == 0:
                 for k in self.top_k_values:
-                    capture_rates[k][disease_idx] = np.nan
+                    capture_rates[int(k)][disease_idx] = np.nan
                 continue

             # For each positive patient, check if true disease is in top-K
+            pos_idx = np.where(pos_mask)[0]
             for k_val in self.top_k_values:
                 captures = []
-                for i in np.where(pos_mask)[0]:
-                    # Get top-K disease indices for this patient
-                    top_k_diseases = np.argsort(y_scores[i])[::-1][:k_val]
-                    # Check if true disease is in top-K
-                    is_captured = disease_idx in top_k_diseases
-                    captures.append(int(is_captured))
-
-                capture_rate = np.mean(captures) if captures else np.nan
-                capture_rates[k_val][disease_idx] = capture_rate
+                for i in pos_idx:
+                    top_k_diseases = np.argsort(y_scores[i])[::-1][:int(k_val)]
+                    captures.append(int(disease_idx in top_k_diseases))
+                capture_rates[int(k_val)][disease_idx] = float(
+                    np.mean(captures)) if captures else np.nan

         return capture_rates
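Note (illustrative, not from this commit): the speed-up in the fast path comes from replacing a full per-row argsort with np.argpartition (average O(K) per row) followed by sorting only the k selected columns. A self-contained sanity sketch of that equivalence on toy data, assuming no tied scores so the selected indices match exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.random((1000, 50))  # random floats, so exact ties are practically impossible
k = 5

# Baseline: full descending argsort of every row, then keep the first k columns.
full_topk = np.argsort(scores, axis=1)[:, ::-1][:, :k]

# Fast path: argpartition selects the k largest per row (unordered),
# then only those k columns are sorted by score.
part = np.argpartition(scores, -k, axis=1)[:, -k:]
order = np.argsort(np.take_along_axis(scores, part, axis=1), axis=1)[:, ::-1]
fast_topk = np.take_along_axis(part, order, axis=1)

# With exact ties the two could break them differently; here they agree index-for-index.
assert np.array_equal(full_topk, fast_topk)
```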
@@ -1558,6 +1686,23 @@ def main():
         help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
     )

+    parser.add_argument(
+        '--check_capture_at_k',
+        action='store_true',
+        help='Run a one-time correctness check: slow vs fast Disease-Capture@K on a small subset'
+    )
+    parser.add_argument(
+        '--profile_metrics',
+        action='store_true',
+        help='Print basic timings for CPU-side metric computations'
+    )
+    parser.add_argument(
+        '--capture_check_n',
+        type=int,
+        default=200,
+        help='Number of samples used for the capture@K slow-vs-fast check (default: 200)'
+    )
+
     args = parser.parse_args()

     # Load model and dataset
@@ -1575,6 +1720,9 @@ def main():
         batch_size=args.batch_size,
         num_workers=args.num_workers,
         compile_model=(not args.no_compile),
+        check_capture_at_k=args.check_capture_at_k,
+        profile_metrics=args.profile_metrics,
+        capture_check_n=args.capture_check_n,
     )

     # Run evaluation
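Usage note (not from the diff): with these flags wired into main(), a run exercising the new options looks like `python evaluate.py ... --check_capture_at_k --profile_metrics --capture_check_n 500`, where `...` stands for the evaluator's existing required arguments (model/data paths, batch size, and so on).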