Enhance compute_disease_capture_at_k_fast to optionally return counts and update LandmarkEvaluator for backward compatibility with the new metrics structure

2026-01-19 00:01:21 +08:00
parent d13fa430b7
commit 76d3fed76f


@@ -94,7 +94,8 @@ def compute_disease_capture_at_k_fast(
y_scores: np.ndarray,
valid_mask: np.ndarray,
top_k_list: List[int],
) -> Dict[int, Dict[int, float]]:
return_counts: bool = False,
) -> Dict[int, Dict[int, float]] | Tuple[Dict[int, Dict[int, float]], np.ndarray, Dict[int, np.ndarray]]:
"""Vectorized Disease-Capture@K.
Definition: for each disease d, among valid positives (y_true==1 and valid_mask),
@@ -116,7 +117,10 @@ def compute_disease_capture_at_k_fast(
top_k_list = sorted({int(k) for k in top_k_list if int(k) > 0})
capture_rates: Dict[int, Dict[int, float]] = {
int(k): {} for k in top_k_list}
hits_by_k: Dict[int, np.ndarray] = {}
if N == 0 or K == 0 or len(top_k_list) == 0:
if return_counts:
return capture_rates, np.zeros((K,), dtype=np.int64), hits_by_k
return capture_rates
topk_max = min(max(top_k_list), K)
@@ -145,6 +149,8 @@ def compute_disease_capture_at_k_fast(
hits_mask = pos_valid[rows, idx_k] # (N, k_eff) bool
hit_diseases = idx_k[hits_mask]
hits = np.bincount(hit_diseases, minlength=K).astype(np.int64)
if return_counts:
hits_by_k[int(k_val)] = hits
# Convert to dict with NaNs for diseases with no valid positives.
out_k: Dict[int, float] = {}
@@ -154,6 +160,8 @@ def compute_disease_capture_at_k_fast(
out_k[d] = float(frac[d]) if denom[d] > 0 else float('nan')
capture_rates[int(k_val)] = out_k
if return_counts:
return capture_rates, denom, hits_by_k
return capture_rates
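Usage sketch for the new flag (illustrative only; the toy arrays and the commented-out import are assumptions, not part of this commit):

import numpy as np
# from <evaluation module> import compute_disease_capture_at_k_fast  # module path is an assumption

y_true = np.array([[1, 0, 1],
                   [0, 1, 0]])                     # (N=2, K=3) binary outcome labels
y_scores = np.array([[0.9, 0.2, 0.1],
                     [0.3, 0.8, 0.4]])             # predicted risks per disease
valid_mask = np.ones_like(y_true, dtype=bool)      # every (patient, disease) cell is evaluable

# Existing call sites are unaffected: the default still returns only the rates.
rates = compute_disease_capture_at_k_fast(
    y_true=y_true, y_scores=y_scores, valid_mask=valid_mask, top_k_list=[1, 2])

# With return_counts=True, per-disease support and hit counts come back as well.
rates, n_pos, hits_by_k = compute_disease_capture_at_k_fast(
    y_true=y_true, y_scores=y_scores, valid_mask=valid_mask,
    top_k_list=[1, 2], return_counts=True)
# rates[2][0]   -> capture rate of disease 0 at K=2 (NaN for diseases with no valid positives)
# n_pos         -> (K,) int64 array of valid positives per disease
# hits_by_k[2]  -> (K,) int64 array of positives captured in each patient's top-2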
@@ -197,6 +205,7 @@ def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]:
summary_rows: List[Dict] = []
auc_rows: List[Dict] = []
capture_rows: List[Dict] = []
capture_mean_rows: List[Dict] = []
lift_rows: List[Dict] = []
dca_rows: List[Dict] = []
@@ -233,16 +242,70 @@ def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]:
if track == 'complete_case':
capture = track_res.get('disease_capture_at_k') or {}
for k_val, per_disease in capture.items():
# Backward-compatible parsing:
# - new format: {per_disease: {k: {d: rate}}, n_positive: {d: n}, macro_avg: {k: x}, micro_avg: {k: y}}
# - old format: {k: {d: rate}}
if isinstance(capture, dict) and 'per_disease' in capture:
per_disease_by_k = capture.get('per_disease') or {}
n_positive_by_disease = capture.get('n_positive') or {}
macro_by_k = capture.get('macro_avg') or {}
micro_by_k = capture.get('micro_avg') or {}
else:
per_disease_by_k = capture
n_positive_by_disease = {}
macro_by_k = {}
micro_by_k = {}
# Per-disease rows (+ n_positive)
for k_val, per_disease in (per_disease_by_k or {}).items():
try:
k_int = int(k_val)
except Exception:
continue
for disease_idx, rate in (per_disease or {}).items():
try:
d_int = int(disease_idx)
except Exception:
continue
n_pos = n_positive_by_disease.get(
d_int, n_positive_by_disease.get(str(d_int), np.nan))
capture_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'k': k_int,
'disease_idx': int(disease_idx),
'disease_idx': d_int,
'capture_rate': _to_float(rate),
'n_positive': _to_float(n_pos),
})
# Macro/Micro summary rows
# Prefer explicit macro/micro from the new format; otherwise compute macro from rates.
for k_val, per_disease in (per_disease_by_k or {}).items():
try:
k_int = int(k_val)
except Exception:
continue
macro = macro_by_k.get(
k_int, macro_by_k.get(str(k_int), np.nan))
micro = micro_by_k.get(
k_int, micro_by_k.get(str(k_int), np.nan))
if macro is None or (isinstance(macro, float) and np.isnan(macro)):
rates = [
_to_float(r) for r in (per_disease or {}).values()
]
macro = float(np.nanmean(rates)) if len(
rates) else np.nan
capture_mean_rows.append({
'age_cutoff': age,
'horizon': horizon,
'track': track,
'k': k_int,
'macro_avg': _to_float(macro),
'micro_avg': _to_float(micro),
})
lift_yield = track_res.get('lift_and_yield') or {}
@@ -307,6 +370,11 @@ def save_results_csv_bundle(results: Dict, out_dir: str) -> Dict[str, str]:
df_capture.to_csv(capture_path, index=False)
paths['capture_at_k'] = capture_path
df_capture_mean = pd.DataFrame(capture_mean_rows)
capture_mean_path = os.path.join(out_dir, 'capture_at_k_mean.csv')
df_capture_mean.to_csv(capture_mean_path, index=False)
paths['capture_at_k_mean'] = capture_mean_path
df_lift = pd.DataFrame(lift_rows)
lift_path = os.path.join(out_dir, 'lift_yield.csv')
df_lift.to_csv(lift_path, index=False)
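For reference, the two result shapes the exporter now accepts (toy values; only the key names come from the parsing code above):

# Old format: rates keyed by k, then disease index.
old_capture = {10: {0: 0.62, 1: 0.50}}

# New format produced by LandmarkEvaluator (n_captured is also present there,
# but the CSV export does not read it).
new_capture = {
    'per_disease': {10: {0: 0.62, 1: 0.50}},
    'n_positive': {0: 137, 1: 48},
    'macro_avg': {10: 0.56},
    'micro_avg': {10: 0.59},
}
# Both forms populate the per-disease rows (paths['capture_at_k']); the new form also
# supplies n_positive and feeds the macro/micro rows written to capture_at_k_mean.csv.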
@@ -921,7 +989,7 @@ class LandmarkEvaluator:
risk_scores: np.ndarray,
labels: np.ndarray,
valid_mask: np.ndarray,
) -> Dict[int, Dict[int, float]]:
) -> Dict[str, Dict]:
"""
Compute Disease-Capture@K: fraction of true positives where the true
disease appears in the patient's top-K predicted risks.
@@ -932,15 +1000,21 @@ class LandmarkEvaluator:
valid_mask: (N, K) boolean mask
Returns:
capture_rates: Dict[K_value][disease_idx] -> capture rate
metrics: Dict with keys:
- per_disease: Dict[k][disease_idx] -> capture rate
- n_positive: Dict[disease_idx] -> number of valid positives (support)
- n_captured: Dict[k][disease_idx] -> number of captured positives
- macro_avg: Dict[k] -> macro-average capture rate
- micro_avg: Dict[k] -> micro-average capture rate
"""
# Fast path (vectorized): compute top-k_max once per sample.
t0 = time.perf_counter() if self.profile_metrics else None
capture_fast = compute_disease_capture_at_k_fast(
capture_rates, denom, hits_by_k = compute_disease_capture_at_k_fast(
y_true=labels,
y_scores=risk_scores,
valid_mask=valid_mask,
top_k_list=self.top_k_values,
return_counts=True,
)
if self.profile_metrics and t0 is not None:
dt = time.perf_counter() - t0
@@ -981,7 +1055,36 @@ class LandmarkEvaluator:
f" [check] capture@K ok on subset (N={n_sub}). slow={t2 - t1:.3f}s fast={t3 - t2:.3f}s"
)
return capture_fast
K = int(risk_scores.shape[1])
n_positive: Dict[int, int] = {int(d): int(denom[d]) for d in range(K)}
n_captured: Dict[int, Dict[int, int]] = {}
macro_avg: Dict[int, float] = {}
micro_avg: Dict[int, float] = {}
total_pos = int(denom.sum())
for k_val in self.top_k_values:
k_int = int(k_val)
hits = hits_by_k.get(k_int, np.zeros((K,), dtype=np.int64))
n_captured[k_int] = {int(d): int(hits[d]) for d in range(K)}
# Macro: mean across diseases with support (ignore NaNs)
rates = capture_rates.get(k_int, {})
rate_values = np.array([rates.get(d, np.nan)
for d in range(K)], dtype=np.float64)
macro_avg[k_int] = float(np.nanmean(
rate_values)) if rate_values.size else float('nan')
# Micro: sum captured / sum positives
micro_avg[k_int] = float(
hits.sum() / total_pos) if total_pos > 0 else float('nan')
return {
'per_disease': capture_rates,
'n_positive': n_positive,
'n_captured': n_captured,
'macro_avg': macro_avg,
'micro_avg': micro_avg,
}
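The macro/micro split above follows the usual convention: macro weights every disease equally (NaN rates from zero-support diseases are skipped by nanmean), while micro weights by support. A toy worked example, not taken from the codebase:

import numpy as np

denom = np.array([10, 90])               # valid positives per disease (support)
hits = np.array([5, 81])                 # positives whose disease landed in the patient's top-K
rates = hits / denom                     # per-disease capture: [0.50, 0.90]
macro = float(np.nanmean(rates))         # 0.70 -- each disease counts equally
micro = float(hits.sum() / denom.sum())  # 0.86 -- pooled hits over pooled positives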
def _compute_disease_capture_at_k_slow(
self,
@@ -1627,9 +1730,15 @@ def print_summary(results: Dict):
if 'disease_capture_at_k' in cc:
print(f" Disease Capture:")
for k in [5, 10, 20, 50]:
if k in cc['disease_capture_at_k']:
rates = list(cc['disease_capture_at_k'][k].values())
mean_rate = np.nanmean(rates)
capture = cc.get('disease_capture_at_k') or {}
if isinstance(capture, dict) and 'macro_avg' in capture:
macro = (capture.get('macro_avg') or {}).get(k, np.nan)
micro = (capture.get('micro_avg') or {}).get(k, np.nan)
print(
f" Top-{k}: macro={_to_float(macro):.3f}, micro={_to_float(micro):.3f}")
elif isinstance(capture, dict) and k in capture:
rates = list((capture.get(k) or {}).values())
mean_rate = np.nanmean([_to_float(r) for r in rates])
print(f" Top-{k}: {mean_rate:.3f}")
# Show lift and yield