Add fast vectorized implementation for Disease-Capture@K and enhance LandmarkEvaluator with profiling and correctness check options

This commit is contained in:
2026-01-18 18:14:45 +08:00
parent 6e76d67a10
commit 6de2132e84

View File

@@ -11,6 +11,7 @@ Implements the comprehensive evaluation framework defined in evaluate_design.md:
import argparse import argparse
import json import json
import os import os
import time
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
import warnings import warnings
@@ -87,6 +88,74 @@ def _to_float(x):
return np.nan return np.nan
def compute_disease_capture_at_k_fast(
    y_true: np.ndarray,
    y_scores: np.ndarray,
    valid_mask: np.ndarray,
    top_k_list: List[int],
) -> Dict[int, Dict[int, float]]:
    """Vectorized Disease-Capture@K.

    Definition: for each disease d, among valid positives (y_true == 1 and
    valid_mask), capture@K is the fraction whose predicted top-K diseases
    contain d.

    The top-k_max indices are computed once per sample (argpartition, then a
    sort of just those k_max entries) and hits are aggregated per disease with
    bincount, so no per-positive full argsort is needed.

    Args:
        y_true: (N, K) label array; entries equal to 1 are positives.
        y_scores: (N, K) score array; higher score ranks earlier.
        valid_mask: (N, K) array, cast to bool; False entries are excluded
            from the set of positives.
        top_k_list: K values to evaluate. Values are cast to int; non-positive
            values are dropped and duplicates collapsed. A value larger than
            the number of diseases behaves like "all diseases".

    Returns:
        Dict mapping each retained K value to a dict of
        disease_idx -> capture rate. Diseases with no valid positives map to
        NaN (including every disease when N == 0). If there are no diseases
        or no usable K values, the inner dicts are empty.

    Raises:
        ValueError: if y_scores is not 2D or the three arrays differ in shape.
    """
    if y_scores.ndim != 2:
        raise ValueError(
            f"Expected y_scores 2D (N,K), got shape={y_scores.shape}")
    if y_true.shape != y_scores.shape or valid_mask.shape != y_scores.shape:
        raise ValueError(
            f"Shape mismatch: y_true={y_true.shape}, y_scores={y_scores.shape}, valid_mask={valid_mask.shape}")

    n_samples, n_diseases = y_scores.shape
    k_values = sorted({int(k) for k in top_k_list if int(k) > 0})
    capture_rates: Dict[int, Dict[int, float]] = {k: {} for k in k_values}

    if n_diseases == 0 or not k_values:
        return capture_rates

    if n_samples == 0:
        # No samples means no valid positives for any disease. Report NaN per
        # disease (same as the zero-denominator case below) so callers can
        # index every disease without hitting a KeyError.
        for k_val in k_values:
            capture_rates[k_val] = {
                d: float('nan') for d in range(n_diseases)}
        return capture_rates

    topk_max = min(k_values[-1], n_diseases)

    # Valid positives per disease are the denominator.
    pos_valid = (y_true == 1) & valid_mask.astype(bool)
    denom = pos_valid.sum(axis=0).astype(np.int64)  # (K,)
    has_positives = denom > 0

    # Top-k_max indices per sample (unordered), then ordered by descending
    # score so that any prefix [:k] is the top-k set.
    part = np.argpartition(y_scores, -topk_max,
                           axis=1)[:, -topk_max:]  # (N, topk_max)
    part_scores = np.take_along_axis(y_scores, part, axis=1)
    order = np.argsort(part_scores, axis=1)[:, ::-1]
    topk_idx = np.take_along_axis(part, order, axis=1)  # (N, topk_max)

    rows = np.arange(n_samples)[:, None]

    for k_val in k_values:
        k_eff = min(k_val, topk_max)
        idx_k = topk_idx[:, :k_eff]  # (N, k_eff)

        # A hit for disease d is a sample whose top-k contains d (true by
        # construction for idx_k) and which is a valid positive for d.
        hits_mask = pos_valid[rows, idx_k]  # (N, k_eff) bool
        hits = np.bincount(idx_k[hits_mask], minlength=n_diseases)

        # Divide only where the denominator is non-zero; the rest stay NaN.
        frac = np.full(n_diseases, np.nan, dtype=np.float64)
        np.divide(hits, denom, out=frac, where=has_positives)

        capture_rates[k_val] = dict(enumerate(frac.tolist()))

    return capture_rates
def save_summary_json(summary: Dict, output_path: str) -> None: def save_summary_json(summary: Dict, output_path: str) -> None:
"""Save a single JSON summary file.""" """Save a single JSON summary file."""
@@ -266,6 +335,9 @@ class LandmarkEvaluator:
batch_size: int = 256, batch_size: int = 256,
num_workers: int = 4, num_workers: int = 4,
compile_model: bool = True, compile_model: bool = True,
check_capture_at_k: bool = False,
profile_metrics: bool = False,
capture_check_n: int = 200,
): ):
self.model = model.to(device).eval() self.model = model.to(device).eval()
self.head = head.to(device).eval() self.head = head.to(device).eval()
@@ -276,6 +348,11 @@ class LandmarkEvaluator:
self.batch_size = batch_size self.batch_size = batch_size
self.num_workers = num_workers self.num_workers = num_workers
self.check_capture_at_k = bool(check_capture_at_k)
self.profile_metrics = bool(profile_metrics)
self.capture_check_n = int(capture_check_n)
self._did_capture_check = False
use_cuda = str(self.device).startswith( use_cuda = str(self.device).startswith(
"cuda") and torch.cuda.is_available() "cuda") and torch.cuda.is_available()
if use_cuda: if use_cuda:
@@ -775,25 +852,26 @@ class LandmarkEvaluator:
Returns: Returns:
auc_scores: Dict mapping disease_idx -> AUC auc_scores: Dict mapping disease_idx -> AUC
""" """
auc_scores = {} auc_scores: Dict[int, float] = {}
n_diseases = risk_scores.shape[1] n_diseases = risk_scores.shape[1]
m = valid_mask.astype(bool)
# Fast pre-filter: skip diseases without both classes among valid entries.
pos = (labels == 1) & m
n_valid = m.sum(axis=0)
n_pos = pos.sum(axis=0)
n_neg = n_valid - n_pos
for k in range(n_diseases): for k in range(n_diseases):
mk = valid_mask[:, k] if n_valid[k] == 0 or n_pos[k] == 0 or n_neg[k] == 0:
if not np.any(mk):
auc_scores[k] = np.nan auc_scores[k] = np.nan
continue continue
y_true = labels[mk, k] mk = m[:, k]
y_score = risk_scores[mk, k] y_true_k = labels[mk, k]
y_score_k = risk_scores[mk, k]
# Check if we have both classes
if len(np.unique(y_true)) < 2:
auc_scores[k] = np.nan
else:
try: try:
auc = roc_auc_score(y_true, y_score) auc_scores[k] = float(roc_auc_score(y_true_k, y_score_k))
auc_scores[k] = auc
except Exception: except Exception:
auc_scores[k] = np.nan auc_scores[k] = np.nan
@@ -855,39 +933,89 @@ class LandmarkEvaluator:
Returns: Returns:
capture_rates: Dict[K_value][disease_idx] -> capture rate capture_rates: Dict[K_value][disease_idx] -> capture rate
""" """
capture_rates = {k: {} for k in self.top_k_values} # Fast path (vectorized): compute top-k_max once per sample.
t0 = time.perf_counter() if self.profile_metrics else None
capture_fast = compute_disease_capture_at_k_fast(
y_true=labels,
y_scores=risk_scores,
valid_mask=valid_mask,
top_k_list=self.top_k_values,
)
if self.profile_metrics and t0 is not None:
dt = time.perf_counter() - t0
print(
f" [profile] capture@K fast: {dt:.3f}s (N={risk_scores.shape[0]}, K={risk_scores.shape[1]})")
# Optional correctness check against the slow reference on a small subset.
if self.check_capture_at_k and (not self._did_capture_check):
self._did_capture_check = True
rng = np.random.default_rng(0)
n_sub = min(self.capture_check_n, risk_scores.shape[0])
sub_idx = rng.choice(
risk_scores.shape[0], size=n_sub, replace=False)
rs = risk_scores[sub_idx]
ys = labels[sub_idx]
vm = valid_mask[sub_idx]
t1 = time.perf_counter()
slow = self._compute_disease_capture_at_k_slow(rs, ys, vm)
t2 = time.perf_counter()
fast = compute_disease_capture_at_k_fast(
ys, rs, vm, self.top_k_values)
t3 = time.perf_counter()
def _eq(a: float, b: float) -> bool:
if np.isnan(a) and np.isnan(b):
return True
return float(a) == float(b)
for k_val in self.top_k_values:
for d in range(rs.shape[1]):
if not _eq(slow[int(k_val)][d], fast[int(k_val)][d]):
raise AssertionError(
f"Capture@{k_val} mismatch for disease {d}: slow={slow[int(k_val)][d]} fast={fast[int(k_val)][d]}"
)
print(
f" [check] capture@K ok on subset (N={n_sub}). slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s"
)
return capture_fast
def _compute_disease_capture_at_k_slow(
self,
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
) -> Dict[int, Dict[int, float]]:
"""Reference implementation (slow): kept for correctness checking."""
capture_rates = {int(k): {} for k in self.top_k_values}
n_diseases = risk_scores.shape[1] n_diseases = risk_scores.shape[1]
for disease_idx in range(n_diseases): for disease_idx in range(n_diseases):
mk = valid_mask[:, disease_idx] mk = valid_mask[:, disease_idx]
if not np.any(mk): if not np.any(mk):
for k in self.top_k_values: for k in self.top_k_values:
capture_rates[k][disease_idx] = np.nan capture_rates[int(k)][disease_idx] = np.nan
continue continue
y_true = labels[mk, disease_idx] y_true = labels[mk, disease_idx]
y_scores = risk_scores[mk] # (N_valid_k, K) y_scores = risk_scores[mk] # (N_valid_k, K)
# Find patients with positive label for this disease
pos_mask = y_true == 1 pos_mask = y_true == 1
if pos_mask.sum() == 0: if pos_mask.sum() == 0:
for k in self.top_k_values: for k in self.top_k_values:
capture_rates[k][disease_idx] = np.nan capture_rates[int(k)][disease_idx] = np.nan
continue continue
# For each positive patient, check if true disease is in top-K pos_idx = np.where(pos_mask)[0]
for k_val in self.top_k_values: for k_val in self.top_k_values:
captures = [] captures = []
for i in np.where(pos_mask)[0]: for i in pos_idx:
# Get top-K disease indices for this patient top_k_diseases = np.argsort(y_scores[i])[::-1][:int(k_val)]
top_k_diseases = np.argsort(y_scores[i])[::-1][:k_val] captures.append(int(disease_idx in top_k_diseases))
# Check if true disease is in top-K capture_rates[int(k_val)][disease_idx] = float(
is_captured = disease_idx in top_k_diseases np.mean(captures)) if captures else np.nan
captures.append(int(is_captured))
capture_rate = np.mean(captures) if captures else np.nan
capture_rates[k_val][disease_idx] = capture_rate
return capture_rates return capture_rates
@@ -1558,6 +1686,23 @@ def main():
help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)' help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
) )
parser.add_argument(
'--check_capture_at_k',
action='store_true',
help='Run a one-time correctness check: slow vs fast Disease-Capture@K on a small subset'
)
parser.add_argument(
'--profile_metrics',
action='store_true',
help='Print basic timings for CPU-side metric computations'
)
parser.add_argument(
'--capture_check_n',
type=int,
default=200,
help='Number of samples used for the capture@K slow-vs-fast check (default: 200)'
)
args = parser.parse_args() args = parser.parse_args()
# Load model and dataset # Load model and dataset
@@ -1575,6 +1720,9 @@ def main():
batch_size=args.batch_size, batch_size=args.batch_size,
num_workers=args.num_workers, num_workers=args.num_workers,
compile_model=(not args.no_compile), compile_model=(not args.no_compile),
check_capture_at_k=args.check_capture_at_k,
profile_metrics=args.profile_metrics,
capture_check_n=args.capture_check_n,
) )
# Run evaluation # Run evaluation