Add fast vectorized implementation for Disease-Capture@K and enhance LandmarkEvaluator with profiling and correctness-check options

2026-01-18 18:14:45 +08:00
parent 6e76d67a10
commit 6de2132e84


@@ -11,6 +11,7 @@ Implements the comprehensive evaluation framework defined in evaluate_design.md:
import argparse
import json
import os
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import warnings
@@ -87,6 +88,74 @@ def _to_float(x):
return np.nan
def compute_disease_capture_at_k_fast(
y_true: np.ndarray,
y_scores: np.ndarray,
valid_mask: np.ndarray,
top_k_list: List[int],
) -> Dict[int, Dict[int, float]]:
"""Vectorized Disease-Capture@K.
Definition: for each disease d, among valid positives (y_true==1 and valid_mask),
capture@K is the fraction whose predicted top-K diseases contain d.
This implementation avoids per-positive full argsorts by computing top-k_max once
per sample (using argpartition), sorting those k_max indices by score, and then
aggregating hits via bincount.
"""
if y_scores.ndim != 2:
raise ValueError(
f"Expected y_scores 2D (N,K), got shape={y_scores.shape}")
if y_true.shape != y_scores.shape or valid_mask.shape != y_scores.shape:
raise ValueError(
f"Shape mismatch: y_true={y_true.shape}, y_scores={y_scores.shape}, valid_mask={valid_mask.shape}")
N, K = y_scores.shape
top_k_list = sorted({int(k) for k in top_k_list if int(k) > 0})
capture_rates: Dict[int, Dict[int, float]] = {
int(k): {} for k in top_k_list}
if N == 0 or K == 0 or len(top_k_list) == 0:
return capture_rates
topk_max = min(max(top_k_list), K)
# Valid positives per disease are the denominator.
pos_valid = (y_true == 1) & valid_mask.astype(bool)
denom = pos_valid.sum(axis=0).astype(np.int64) # (K,)
# Compute top-k_max indices per sample once (unordered), then sort those indices by score.
part = np.argpartition(y_scores, -topk_max,
axis=1)[:, -topk_max:] # (N, topk_max)
part_scores = np.take_along_axis(y_scores, part, axis=1)
order = np.argsort(part_scores, axis=1)[:, ::-1]
topk_idx = np.take_along_axis(
part, order, axis=1).astype(np.int32) # (N, topk_max)
rows = np.arange(N)[:, None]
for k_val in top_k_list:
k_eff = min(int(k_val), topk_max)
idx_k = topk_idx[:, :k_eff] # (N, k_eff)
# For each sample, we count a "hit" for disease d when:
# d is in top-K (true by construction for idx_k)
# AND sample is a valid positive for disease d.
hits_mask = pos_valid[rows, idx_k] # (N, k_eff) bool
hit_diseases = idx_k[hits_mask]
hits = np.bincount(hit_diseases, minlength=K).astype(np.int64)
# Convert to dict with NaNs for diseases with no valid positives.
out_k: Dict[int, float] = {}
with np.errstate(divide='ignore', invalid='ignore'):
frac = hits / denom
for d in range(K):
out_k[d] = float(frac[d]) if denom[d] > 0 else float('nan')
capture_rates[int(k_val)] = out_k
return capture_rates
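
For illustration, a minimal usage sketch of the fast path (not part of the diff; the toy values are chosen so the expected rates can be checked by hand):

import numpy as np

scores = np.array([[0.9, 0.1, 0.5],
                   [0.2, 0.8, 0.3]])  # (N=2 samples, K=3 diseases)
labels = np.array([[1, 0, 1],
                   [0, 1, 0]])
valid = np.ones_like(labels, dtype=bool)

rates = compute_disease_capture_at_k_fast(labels, scores, valid, top_k_list=[1, 2])
# Sample 0 is positive for diseases 0 and 2: its top-1 is {0}, its top-2 is {0, 2}.
# Sample 1 is positive for disease 1: its top-1 is {1}.
# => rates[1] == {0: 1.0, 1: 1.0, 2: 0.0} and rates[2][2] == 1.0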
def save_summary_json(summary: Dict, output_path: str) -> None:
"""Save a single JSON summary file."""
@@ -266,6 +335,9 @@ class LandmarkEvaluator:
batch_size: int = 256,
num_workers: int = 4,
compile_model: bool = True,
check_capture_at_k: bool = False,
profile_metrics: bool = False,
capture_check_n: int = 200,
):
self.model = model.to(device).eval()
self.head = head.to(device).eval()
@@ -276,6 +348,11 @@ class LandmarkEvaluator:
self.batch_size = batch_size
self.num_workers = num_workers
self.check_capture_at_k = bool(check_capture_at_k)
self.profile_metrics = bool(profile_metrics)
self.capture_check_n = int(capture_check_n)
self._did_capture_check = False
use_cuda = str(self.device).startswith(
"cuda") and torch.cuda.is_available()
if use_cuda:
@@ -775,27 +852,28 @@ class LandmarkEvaluator:
Returns:
auc_scores: Dict mapping disease_idx -> AUC
"""
auc_scores: Dict[int, float] = {}
n_diseases = risk_scores.shape[1]
m = valid_mask.astype(bool)
# Fast pre-filter: skip diseases without both classes among valid entries.
pos = (labels == 1) & m
n_valid = m.sum(axis=0)
n_pos = pos.sum(axis=0)
n_neg = n_valid - n_pos
for k in range(n_diseases):
    if n_valid[k] == 0 or n_pos[k] == 0 or n_neg[k] == 0:
        auc_scores[k] = np.nan
        continue
    mk = m[:, k]
    y_true_k = labels[mk, k]
    y_score_k = risk_scores[mk, k]
    try:
        auc_scores[k] = float(roc_auc_score(y_true_k, y_score_k))
    except Exception:
        auc_scores[k] = np.nan
return auc_scores
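
For context, a self-contained sketch of the masked per-disease AUC pattern above on toy data (illustrative only; shows the NaN fallback when a disease has no valid entries or only one class):

import numpy as np
from sklearn.metrics import roc_auc_score

labels = np.array([[1, 1], [0, 1], [1, 0]])
scores = np.array([[0.9, 0.2], [0.1, 0.7], [0.8, 0.4]])
mask = np.array([[True, False], [True, False], [True, False]])

pos = (labels == 1) & mask
n_valid, n_pos = mask.sum(axis=0), pos.sum(axis=0)
for k in range(labels.shape[1]):
    if n_valid[k] == 0 or n_pos[k] == 0 or n_pos[k] == n_valid[k]:
        print(k, float('nan'))  # no valid entries, or only one class present
        continue
    mk = mask[:, k]
    print(k, roc_auc_score(labels[mk, k], scores[mk, k]))
# disease 0 -> 1.0 (positives 0.9 and 0.8 outrank the negative 0.1); disease 1 -> nan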
@@ -855,39 +933,89 @@ class LandmarkEvaluator:
Returns:
capture_rates: Dict[K_value][disease_idx] -> capture rate
"""
# Fast path (vectorized): compute top-k_max once per sample.
t0 = time.perf_counter() if self.profile_metrics else None
capture_fast = compute_disease_capture_at_k_fast(
y_true=labels,
y_scores=risk_scores,
valid_mask=valid_mask,
top_k_list=self.top_k_values,
)
if self.profile_metrics and t0 is not None:
dt = time.perf_counter() - t0
print(
f" [profile] capture@K fast: {dt:.3f}s (N={risk_scores.shape[0]}, K={risk_scores.shape[1]})")
# Optional correctness check against the slow reference on a small subset.
if self.check_capture_at_k and (not self._did_capture_check):
self._did_capture_check = True
rng = np.random.default_rng(0)
n_sub = min(self.capture_check_n, risk_scores.shape[0])
sub_idx = rng.choice(
risk_scores.shape[0], size=n_sub, replace=False)
rs = risk_scores[sub_idx]
ys = labels[sub_idx]
vm = valid_mask[sub_idx]
t1 = time.perf_counter()
slow = self._compute_disease_capture_at_k_slow(rs, ys, vm)
t2 = time.perf_counter()
fast = compute_disease_capture_at_k_fast(
ys, rs, vm, self.top_k_values)
t3 = time.perf_counter()
def _eq(a: float, b: float) -> bool:
if np.isnan(a) and np.isnan(b):
return True
return float(a) == float(b)
for k_val in self.top_k_values:
for d in range(rs.shape[1]):
if not _eq(slow[int(k_val)][d], fast[int(k_val)][d]):
raise AssertionError(
f"Capture@{k_val} mismatch for disease {d}: slow={slow[int(k_val)][d]} fast={fast[int(k_val)][d]}"
)
print(
f" [check] capture@K ok on subset (N={n_sub}). slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s"
)
return capture_fast
def _compute_disease_capture_at_k_slow(
self,
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
) -> Dict[int, Dict[int, float]]:
"""Reference implementation (slow): kept for correctness checking."""
capture_rates = {int(k): {} for k in self.top_k_values}
n_diseases = risk_scores.shape[1]
for disease_idx in range(n_diseases):
mk = valid_mask[:, disease_idx]
if not np.any(mk):
for k in self.top_k_values:
capture_rates[int(k)][disease_idx] = np.nan
continue
y_true = labels[mk, disease_idx]
y_scores = risk_scores[mk] # (N_valid_k, K)
# Find patients with positive label for this disease
pos_mask = y_true == 1
if pos_mask.sum() == 0:
for k in self.top_k_values:
capture_rates[int(k)][disease_idx] = np.nan
continue
# For each positive patient, check if true disease is in top-K
pos_idx = np.where(pos_mask)[0]
for k_val in self.top_k_values:
captures = []
for i in pos_idx:
    # Top-K disease indices for this patient, highest score first.
    top_k_diseases = np.argsort(y_scores[i])[::-1][:int(k_val)]
    captures.append(int(disease_idx in top_k_diseases))
capture_rates[int(k_val)][disease_idx] = float(
    np.mean(captures)) if captures else np.nan
return capture_rates
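
Relatedly, the one-time slow-vs-fast agreement check can be reproduced standalone; a sketch on random data, where a brute-force loop stands in for the slow class method (assumes the module-level fast function above):

import numpy as np

rng = np.random.default_rng(0)
N, K = 50, 10
scores = rng.random((N, K))  # continuous scores avoid ties, so both rankings agree
labels = (rng.random((N, K)) < 0.3).astype(int)
valid = rng.random((N, K)) < 0.9

fast = compute_disease_capture_at_k_fast(labels, scores, valid, [1, 3, 5])
for k, per_disease in fast.items():
    for d, rate in per_disease.items():
        pos = np.where((labels[:, d] == 1) & valid[:, d])[0]
        if len(pos) == 0:
            assert np.isnan(rate)  # no valid positives -> NaN by convention
            continue
        hits = sum(d in np.argsort(scores[i])[::-1][:k] for i in pos)
        assert abs(rate - hits / len(pos)) < 1e-12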
@@ -1558,6 +1686,23 @@ def main():
help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
)
parser.add_argument(
'--check_capture_at_k',
action='store_true',
help='Run a one-time correctness check: slow vs fast Disease-Capture@K on a small subset'
)
parser.add_argument(
'--profile_metrics',
action='store_true',
help='Print basic timings for CPU-side metric computations'
)
parser.add_argument(
'--capture_check_n',
type=int,
default=200,
help='Number of samples used for the capture@K slow-vs-fast check (default: 200)'
)
args = parser.parse_args()
# Load model and dataset
@@ -1575,6 +1720,9 @@ def main():
batch_size=args.batch_size,
num_workers=args.num_workers,
compile_model=(not args.no_compile),
check_capture_at_k=args.check_capture_at_k,
profile_metrics=args.profile_metrics,
capture_check_n=args.capture_check_n,
)
# Run evaluation
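
As a usage note, an evaluation run with the new options might be launched as follows (the script name is hypothetical and "..." stands for the unchanged model/dataset arguments; only the three flags come from this diff):

python evaluate.py ... --check_capture_at_k --profile_metrics --capture_check_n 500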