"""Horizon-capture evaluation (event-driven, age-stratified). This script implements the protocol described in 评估方案.md: - Age-stratified evaluation: metrics are computed independently within each age bin (no mixing). - Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin disease events with t0 >= DOA. - No follow-up completeness filtering (no t0+tau <= t_end constraint). Primary outputs per age bin: - Top-K Event Capture@tau (event-count based) - Workload–Yield curves (Top-p% people by a person-level horizon score) Secondary (diagnostic-only) outputs per age bin: - Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment) """ import argparse import math import os from typing import Dict, List, Sequence, Tuple import numpy as np import pandas as pd import torch from torch.utils.data import DataLoader try: from tqdm import tqdm except Exception: # pragma: no cover tqdm = None from utils import ( EvalRecordDataset, build_dataset_from_config, build_event_driven_records, build_model_head_criterion, eval_collate_fn, flatten_future_events, get_test_subset, load_checkpoint_into, load_train_config, make_inference_dataloader_kwargs, parse_float_list, predict_cifs, roc_auc_ovr, seed_everything, topk_indices, DAYS_PER_YEAR, ) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="Evaluate horizon-capture using CIF at horizons") p.add_argument("--run_dir", type=str, required=True) p.add_argument( "--horizons", type=str, nargs="+", default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="Horizon grid in years", ) p.add_argument( "--age_bins", type=str, nargs="+", default=["40", "45", "50", "55", "60", "65", "70", "inf"], help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)", ) p.add_argument( "--device", type=str, default=("cuda" if torch.cuda.is_available() else "cpu"), ) p.add_argument("--batch_size", type=int, default=256) p.add_argument("--num_workers", type=int, default=0) p.add_argument("--seed", type=int, default=0) p.add_argument("--min_pos", type=int, default=20) p.add_argument( "--topk_list", type=int, nargs="+", default=[5, 10, 20, 50], ) p.add_argument( "--workload_fracs", type=float, nargs="+", default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5], help="Fractions for workload–yield curves (Top-p%% people).", ) p.add_argument( "--no_tqdm", action="store_true", help="Disable tqdm progress bars", ) return p.parse_args() def _format_age_bin_label(lo: float, hi: float) -> str: if np.isinf(hi): return f"[{lo}, inf)" return f"[{lo}, {hi})" def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray: age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64) age_bins_days = age_bins_years * DAYS_PER_YEAR return np.searchsorted(age_bins_days, t0_days, side="left") - 1 def _event_counts_within_tau( n_records: int, event_record_idx: np.ndarray, event_dt_years: np.ndarray, tau_years: float, ) -> np.ndarray: """Count events within (t0, t0+tau] per record (event-count, not unique causes).""" if event_record_idx.size == 0: return np.zeros((n_records,), dtype=np.int64) m = event_dt_years <= float(tau_years) if not np.any(m): return np.zeros((n_records,), dtype=np.int64) return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64) def build_labels_within_tau_flat( n_records: int, n_causes: int, event_record_idx: np.ndarray, event_cause: np.ndarray, event_dt_years: 


def build_labels_within_tau_flat(
    n_records: int,
    n_causes: int,
    event_record_idx: np.ndarray,
    event_cause: np.ndarray,
    event_dt_years: np.ndarray,
    tau_years: float,
) -> np.ndarray:
    """Build y_within_tau from a flattened (record, cause, dt) representation.

    This preserves the exact label definition: y[i, k] = 1 iff at least one
    event of cause k occurs in (t0, t0 + tau].
    """
    y = np.zeros((n_records, n_causes), dtype=np.int8)
    if event_dt_years.size == 0:
        return y
    m = event_dt_years <= float(tau_years)
    if not np.any(m):
        return y
    y[event_record_idx[m], event_cause[m]] = 1
    return y


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    show_progress = not args.no_tqdm

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    horizons = [float(h) for h in parse_float_list(args.horizons)]

    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    # Print disclaimers on every run (requested).
    print("PRIMARY METRICS: event-count Capture@K and Workload–Yield, computed independently per age bin.")
    print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).")

    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        horizons,
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (horizons)",
    )
    # scores shape: (N, K, H) = (records, causes, horizons).
    if scores.ndim != 3:
        raise ValueError(
            f"Expected CIF scores with shape (N, K, H), got {scores.shape}")
    N, K, H = scores.shape
    if N != len(records):
        raise ValueError("Record count mismatch")

    # Pre-flatten all future events once to avoid repeated per-record scans.
    # NOTE: these are event-level arrays (not unique causes), suitable for
    # event-count Capture@K.
    evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)

    # Assign each record to an age bin (based on t0; by construction t0 lies
    # within the bin).
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = _assign_age_bin_idx(t0_days, age_bins_years)
    age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64)

    capture_rows: List[Dict[str, object]] = []
    workload_rows: List[Dict[str, object]] = []
    # Diagnostics (optional): approximate event-driven AUC/Brier per bin.
    diag_rows: List[Dict[str, object]] = []
    diag_per_cause_parts: List[pd.DataFrame] = []

    bins_iter = range(len(age_bins_years_arr) - 1)
    if show_progress and tqdm is not None:
        bins_iter = tqdm(bins_iter, total=len(age_bins_years_arr) - 1,
                         desc="Age bins")
    for b in bins_iter:
        lo = float(age_bins_years_arr[b])
        hi = float(age_bins_years_arr[b + 1])
        age_label = _format_age_bin_label(lo, hi)
        m_rec = bin_idx == b
        n_bin = int(m_rec.sum())
        if n_bin == 0:
            continue
        rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32)
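        # Sketch of the event-level filtering below (illustrative values only,
        # not part of the script's logic): m_rec is a record-level mask, and
        # indexing it with evt_rec_idx lifts it to one flag per event, e.g.
        #   m_rec       = [True, False, True]   # records 0 and 2 in this bin
        #   evt_rec_idx = [0, 0, 1, 2]          # one entry per event
        #   m_rec[evt_rec_idx] -> [True, True, False, True]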
        # Filter events to this bin's records once.
        m_evt_bin = (m_rec[evt_rec_idx] if evt_rec_idx.size > 0
                     else np.zeros((0,), dtype=bool))
        evt_rec_idx_b = evt_rec_idx[m_evt_bin]
        evt_cause_b = evt_cause[m_evt_bin]
        evt_dt_b = evt_dt[m_evt_bin]

        horizon_iter = enumerate(horizons)
        if show_progress and tqdm is not None:
            horizon_iter = tqdm(horizon_iter, total=len(horizons),
                                desc=f"Horizons {age_label}")

        # Map global record indices to bin-local positions (used below for
        # per-record event counts and for diagnostics label building).
        local_map = np.full((N,), -1, dtype=np.int32)
        local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32)

        for h_idx, tau in horizon_iter:
            s_tau_all = scores[:, :, h_idx]
            s_tau = s_tau_all[m_rec]

            # -------------------------
            # Primary metric: Top-K Event Capture@tau (event-count based)
            # -------------------------
            denom_events = (int(np.sum(evt_dt_b <= float(tau)))
                            if evt_dt_b.size > 0 else 0)
            if denom_events == 0:
                for topk in args.topk_list:
                    capture_rows.append(
                        {
                            "age_bin": age_label,
                            "tau_years": float(tau),
                            "topk": int(topk),
                            "capture_at_k": float("nan"),
                            "denom_events": 0,
                            "numer_events": 0,
                            "n_records": int(n_bin),
                            "n_causes": int(K),
                        }
                    )
            else:
                m_evt_tau = evt_dt_b <= float(tau)
                evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau]
                evt_cause_tau = evt_cause_b[m_evt_tau]
                # For each K, check whether each event's cause appears in that
                # record's Top-K list.
                for topk in args.topk_list:
                    topk = int(topk)
                    idx = topk_indices(s_tau_all, topk)  # shape (N, topk)
                    idx_for_events = idx[evt_rec_idx_tau]
                    hits = (idx_for_events == evt_cause_tau[:, None]).any(axis=1)
                    numer_events = int(hits.sum())
                    capture = float(numer_events / denom_events)
                    capture_rows.append(
                        {
                            "age_bin": age_label,
                            "tau_years": float(tau),
                            "topk": int(topk),
                            "capture_at_k": capture,
                            "denom_events": int(denom_events),
                            "numer_events": int(numer_events),
                            "n_records": int(n_bin),
                            "n_causes": int(K),
                        }
                    )

            # -------------------------
            # Primary metric: Workload–Yield (Top-p% people)
            # -------------------------
            # Person-level score: max_k CIF_k(tau). This is used only for
            # workload–yield ranking.
            person_score = (s_tau.max(axis=1) if K > 0
                            else np.zeros((n_bin,), dtype=np.float64))
            order = np.argsort(-person_score, kind="mergesort")
            counts_per_record = _event_counts_within_tau(
                n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau)
            total_events = int(counts_per_record.sum())
            overall_events_per_person = (total_events / float(n_bin)
                                         if n_bin > 0 else float("nan"))
            for frac in args.workload_fracs:
                frac = float(frac)
                if frac <= 0.0:
                    continue
                n_sel = int(math.ceil(frac * n_bin))
                n_sel = min(max(n_sel, 1), n_bin)
                sel_local = order[:n_sel]
                events_captured = int(counts_per_record[sel_local].sum())
                capture_rate = (float(events_captured / total_events)
                                if total_events > 0 else float("nan"))
                selected_events_per_person = (events_captured / float(n_sel)
                                              if n_sel > 0 else float("nan"))
                lift = (selected_events_per_person / overall_events_per_person
                        if overall_events_per_person > 0 else float("nan"))
                workload_rows.append(
                    {
                        "age_bin": age_label,
                        "tau_years": float(tau),
                        "frac_selected": float(frac),
                        "n_selected": int(n_sel),
                        "n_records": int(n_bin),
                        "total_events": int(total_events),
                        "events_captured": int(events_captured),
                        "capture_rate": capture_rate,
                        "lift_events_per_person": float(lift),
                        "person_score_def": "max_k_CIF_k(tau)",
                    }
                )

            # -------------------------
            # Diagnostics (optional): approximate event-driven AUC/Brier
            # -------------------------
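            # Tiny worked example of the label build below (made-up values,
            # not executed): with tau = 1.0 and in-bin events
            #   (record 0, cause 3, dt 0.4), (record 0, cause 3, dt 0.9),
            #   (record 2, cause 1, dt 1.7),
            # the two early events collapse to a single positive y[0, 3] = 1,
            # while y[2, 1] stays 0 because its dt exceeds tau.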
            # Convert event-level data to binary labels: y[i, k] = 1 iff >= 1
            # event of cause k occurs within tau.
            y_tau_bin = np.zeros((n_bin, K), dtype=np.int8)
            if evt_dt_b.size > 0:
                m_evt_tau = evt_dt_b <= float(tau)
                if np.any(m_evt_tau):
                    rec_local = local_map[evt_rec_idx_b[m_evt_tau]]
                    valid = rec_local >= 0
                    y_tau_bin[rec_local[valid],
                              evt_cause_b[m_evt_tau][valid]] = 1
            n_pos = y_tau_bin.sum(axis=0).astype(np.int64)
            n_neg = (int(n_bin) - n_pos).astype(np.int64)
            brier_per_cause = np.mean(
                (y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2,
                axis=0,
            )
            brier_macro = (float(np.mean(brier_per_cause))
                           if K > 0 else float("nan"))
            brier_weighted = (float(np.sum(brier_per_cause * n_pos) / np.sum(n_pos))
                              if np.sum(n_pos) > 0 else float("nan"))
            auc = np.full((K,), np.nan, dtype=np.float64)
            min_pos = int(args.min_pos)
            candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
            for k in candidates:
                auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype(np.int32),
                                     s_tau[:, k].astype(np.float64))
            finite_auc = auc[np.isfinite(auc)]
            auc_macro = (float(np.mean(finite_auc))
                         if finite_auc.size > 0 else float("nan"))
            w_mask = np.isfinite(auc)
            auc_weighted = (float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(n_pos[w_mask]))
                            if np.sum(n_pos[w_mask]) > 0 else float("nan"))
            n_valid_auc = int(np.isfinite(auc).sum())
            diag_rows.append(
                {
                    "age_bin": age_label,
                    "tau_years": float(tau),
                    "n_records": int(n_bin),
                    "n_causes": int(K),
                    "auc_macro": auc_macro,
                    "auc_weighted_by_npos": auc_weighted,
                    "n_causes_valid_auc": int(n_valid_auc),
                    "brier_macro": brier_macro,
                    "brier_weighted_by_npos": brier_weighted,
                }
            )
            diag_per_cause_parts.append(
                pd.DataFrame(
                    {
                        "age_bin": age_label,
                        "tau_years": float(tau),
                        "cause_id": np.arange(K, dtype=np.int64),
                        "n_pos": n_pos,
                        "n_neg": n_neg,
                        "auc": auc,
                        "brier": brier_per_cause,
                    }
                )
            )

    out_capture = os.path.join(run_dir, "horizon_capture.csv")
    out_wy = os.path.join(run_dir, "workload_yield.csv")
    out_diag = os.path.join(run_dir, "horizon_metrics.csv")
    out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv")
    pd.DataFrame(capture_rows).to_csv(out_capture, index=False)
    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
    pd.DataFrame(diag_rows).to_csv(out_diag, index=False)
    if diag_per_cause_parts:
        pd.concat(diag_per_cause_parts, ignore_index=True).to_csv(
            out_diag_pc, index=False)
    else:
        pd.DataFrame(
            columns=["age_bin", "tau_years", "cause_id",
                     "n_pos", "n_neg", "auc", "brier"]
        ).to_csv(out_diag_pc, index=False)
    print(f"Wrote {out_capture}")
    print(f"Wrote {out_wy}")
    print(f"Wrote {out_diag} (diagnostic-only)")
    print(f"Wrote {out_diag_pc} (diagnostic-only)")


if __name__ == "__main__":
    main()
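
# Example invocation (a sketch; the script filename and run directory are
# placeholders, and --run_dir must contain the train config and checkpoint
# expected by load_train_config / load_checkpoint_into):
#
#   python horizon_capture_eval.py --run_dir runs/my_run \
#       --horizons 0.5 1.0 2.0 --age_bins 40 50 60 inf --topk_list 10 20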