"""Horizon-capture evaluation (event-driven, age-stratified).
|
||
|
||
This script implements the protocol described in 评估方案.md:
|
||
|
||
- Age-stratified evaluation: metrics are computed independently within each age bin (no mixing).
|
||
- Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and
|
||
there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin
|
||
disease events with t0 >= DOA.
|
||
- No follow-up completeness filtering (no t0+tau <= t_end constraint).
|
||
|
||
Primary outputs per age bin:
|
||
- Top-K Event Capture@tau (event-count based)
|
||
- Workload–Yield curves (Top-p% people by a person-level horizon score)
|
||
|
||
Secondary (diagnostic-only) outputs per age bin:
|
||
- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment)
|
||
"""
|
||
|
||
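# Example invocation (a sketch; the script filename and run directory are
# hypothetical, the flags are the ones defined in parse_args below):
#   python eval_horizon_capture.py --run_dir runs/exp01 \
#       --horizons 0.5 1.0 5.0 --age_bins 40 50 60 inf --topk_list 10 20
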
import argparse
import math
import os
from typing import Dict, List, Sequence

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm
except Exception:  # pragma: no cover
    tqdm = None

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    flatten_future_events,
    get_test_subset,
    load_checkpoint_into,
    load_train_config,
    make_inference_dataloader_kwargs,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate horizon-capture using CIF at horizons")
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument(
        "--horizons",
        type=str,
        nargs="+",
        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
        help="Horizon grid in years",
    )
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives per cause for the diagnostic per-cause AUC",
    )
    p.add_argument(
        "--topk_list",
        type=int,
        nargs="+",
        default=[5, 10, 20, 50],
        help="K values for Top-K Event Capture@tau",
    )
    p.add_argument(
        "--workload_fracs",
        type=float,
        nargs="+",
        default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5],
        help="Fractions for workload–yield curves (Top-p%% people).",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def _format_age_bin_label(lo: float, hi: float) -> str:
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"


def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray:
    """Map each t0 (in days) to the index of its left-closed age bin [lo, hi)."""
    age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    # side="right" keeps bins left-closed: a t0 exactly on a boundary falls into
    # the bin that starts there, matching the [lo, hi) labels above (side="left"
    # would push boundary values into the bin below, or to -1 at the first edge).
    return np.searchsorted(age_bins_days, t0_days, side="right") - 1


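# Worked example for _assign_age_bin_idx (a sketch): with
# age_bins_years = (40, 45, 50), a t0 at exactly age 45 maps to bin 1, i.e.
# [45, 50), and age 44.9 to bin 0, i.e. [40, 45); a t0 below 40 maps to -1,
# which matches no bin in the evaluation loop below and is therefore skipped.
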
def _event_counts_within_tau(
    n_records: int,
    event_record_idx: np.ndarray,
    event_dt_years: np.ndarray,
    tau_years: float,
) -> np.ndarray:
    """Count events within (t0, t0+tau] per record (event-count, not unique causes)."""
    if event_record_idx.size == 0:
        return np.zeros((n_records,), dtype=np.int64)
    m = event_dt_years <= float(tau_years)
    if not np.any(m):
        return np.zeros((n_records,), dtype=np.int64)
    return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64)


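# Worked example for _event_counts_within_tau (a sketch): with n_records=3,
# event_record_idx=[0, 0, 2], event_dt_years=[0.4, 1.2, 0.9] and tau_years=1.0,
# the result is [1, 0, 1]: record 0's 1.2-year event falls outside the horizon.
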
def build_labels_within_tau_flat(
    n_records: int,
    n_causes: int,
    event_record_idx: np.ndarray,
    event_cause: np.ndarray,
    event_dt_years: np.ndarray,
    tau_years: float,
) -> np.ndarray:
    """Build y_within_tau using a flattened (record, cause, dt) representation.

    This preserves the exact label definition: y[i, k] = 1 iff at least one
    event of cause k occurs in (t0, t0+tau].
    """
    y = np.zeros((n_records, n_causes), dtype=np.int8)
    if event_dt_years.size == 0:
        return y
    m = event_dt_years <= float(tau_years)
    if not np.any(m):
        return y
    y[event_record_idx[m], event_cause[m]] = 1
    return y


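# Usage sketch for build_labels_within_tau_flat: given the flattened arrays
# returned by flatten_future_events, e.g.
#   y = build_labels_within_tau_flat(N, K, evt_rec_idx, evt_cause, evt_dt, 1.0)
# main() below applies the same labeling inline, restricted to each age bin.
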
def main() -> None:
    args = parse_args()
    seed_everything(args.seed)

    show_progress = not args.no_tqdm

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)

    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    horizons = [float(h) for h in parse_float_list(args.horizons)]

    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)

    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    # Print disclaimers on every run (requested).
    print("PRIMARY METRICS: event-count Capture@K and Workload–Yield, computed independently per age bin.")
    print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).")

    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        horizons,
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (horizons)",
    )
    # scores shape: (N, K, H) = (records, causes, horizons)
    if scores.ndim != 3:
        raise ValueError(
            f"Expected CIF scores with shape (N, K, H), got {scores.shape}")

    N, K, H = scores.shape
    if N != len(records):
        raise ValueError(
            f"Record count mismatch: {N} score rows vs {len(records)} records")

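    # Reading (standard cumulative-incidence convention; the exact estimator is
    # defined by predict_cifs): scores[i, k, h] estimates CIF_k(tau_h) for
    # record i, i.e. the probability that a cause-k event occurs in
    # (t0, t0 + tau_h] given the history up to t0.
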
    # Pre-flatten all future events once to avoid repeated per-record scans.
    # NOTE: these are event-level arrays (not unique causes), suitable for
    # event-count Capture@K.
    evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)

    # Assign each record to an age bin (based on t0; by construction t0 lies
    # within the bin).
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = _assign_age_bin_idx(t0_days, age_bins_years)
    age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64)

    capture_rows: List[Dict[str, object]] = []
    workload_rows: List[Dict[str, object]] = []

    # Diagnostics (optional): approximate event-driven AUC/Brier computed per bin.
    diag_rows: List[Dict[str, object]] = []
    diag_per_cause_parts: List[pd.DataFrame] = []

    n_bins = len(age_bins_years_arr) - 1
    bins_iter = range(n_bins)
    if show_progress and tqdm is not None:
        bins_iter = tqdm(bins_iter, total=n_bins, desc="Age bins")

    for b in bins_iter:
        lo = float(age_bins_years_arr[b])
        hi = float(age_bins_years_arr[b + 1])
        age_label = _format_age_bin_label(lo, hi)

        m_rec = bin_idx == b
        n_bin = int(m_rec.sum())
        if n_bin == 0:
            continue

        rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32)

        # Filter events down to this bin's records once.
        if evt_rec_idx.size > 0:
            m_evt_bin = m_rec[evt_rec_idx]
        else:
            m_evt_bin = np.zeros((0,), dtype=bool)
        evt_rec_idx_b = evt_rec_idx[m_evt_bin]
        evt_cause_b = evt_cause[m_evt_bin]
        evt_dt_b = evt_dt[m_evt_bin]

        horizon_iter = enumerate(horizons)
        if show_progress and tqdm is not None:
            horizon_iter = tqdm(horizon_iter, total=len(horizons),
                                desc=f"Horizons {age_label}")

        # Precompute a global->local index map for this bin's records; it is
        # used both for workload event counts and for diagnostics labels.
        local_map = np.full((N,), -1, dtype=np.int32)
        local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32)

        for h_idx, tau in horizon_iter:
            s_tau_all = scores[:, :, h_idx]  # (N, K) CIF at this horizon
            s_tau = s_tau_all[m_rec]  # (n_bin, K)

            # -------------------------
            # Primary metric: Top-K Event Capture@tau (event-count based)
            # -------------------------
            # Capture@K(tau) = (# events in (t0, t0+tau] whose cause is among
            # the record's Top-K causes by CIF(tau)) / (# events in (t0, t0+tau]),
            # counted over all records in the bin.
            denom_events = int(np.sum(evt_dt_b <= float(tau))) if evt_dt_b.size > 0 else 0
            if denom_events == 0:
                for topk in args.topk_list:
                    capture_rows.append(
                        {
                            "age_bin": age_label,
                            "tau_years": float(tau),
                            "topk": int(topk),
                            "capture_at_k": float("nan"),
                            "denom_events": 0,
                            "numer_events": 0,
                            "n_records": int(n_bin),
                            "n_causes": int(K),
                        }
                    )
            else:
                m_evt_tau = evt_dt_b <= float(tau)
                evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau]
                evt_cause_tau = evt_cause_b[m_evt_tau]

                # For each K, check whether each event's cause appears in that
                # record's Top-K list.
                for topk in args.topk_list:
                    topk = int(topk)
                    idx = topk_indices(s_tau_all, topk)  # (N, topk)
                    idx_for_events = idx[evt_rec_idx_tau]
                    hits = (idx_for_events == evt_cause_tau[:, None]).any(axis=1)
                    numer_events = int(hits.sum())
                    capture = float(numer_events / denom_events)
                    capture_rows.append(
                        {
                            "age_bin": age_label,
                            "tau_years": float(tau),
                            "topk": topk,
                            "capture_at_k": capture,
                            "denom_events": int(denom_events),
                            "numer_events": numer_events,
                            "n_records": int(n_bin),
                            "n_causes": int(K),
                        }
                    )

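            # Micro-example (a sketch): with topk=2, an event of cause 3 on a
            # record whose two highest-CIF causes at tau are (7, 3) is a hit;
            # an event of cause 9 on that record is a miss. Both events count
            # toward denom_events.
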
            # -------------------------
            # Primary metric: Workload–Yield (Top-p% people)
            # -------------------------
            # Person-level score: max_k CIF_k(tau), used only to rank people
            # for the workload–yield curve.
            person_score = (s_tau.max(axis=1) if K > 0
                            else np.zeros((n_bin,), dtype=np.float64))
            order = np.argsort(-person_score, kind="mergesort")  # stable, descending

            counts_per_record = _event_counts_within_tau(
                n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau)
            total_events = int(counts_per_record.sum())
            overall_events_per_person = (total_events / float(n_bin)
                                         if n_bin > 0 else float("nan"))

            for frac in args.workload_fracs:
                frac = float(frac)
                if frac <= 0.0:
                    continue
                n_sel = min(max(int(math.ceil(frac * n_bin)), 1), n_bin)
                sel_local = order[:n_sel]
                events_captured = int(counts_per_record[sel_local].sum())
                capture_rate = (events_captured / total_events
                                if total_events > 0 else float("nan"))

                # n_sel >= 1 by construction, so this division is safe.
                selected_events_per_person = events_captured / float(n_sel)
                # Lift = events per selected person / events per person overall.
                lift = (selected_events_per_person / overall_events_per_person
                        if overall_events_per_person > 0 else float("nan"))

                workload_rows.append(
                    {
                        "age_bin": age_label,
                        "tau_years": float(tau),
                        "frac_selected": frac,
                        "n_selected": int(n_sel),
                        "n_records": int(n_bin),
                        "total_events": int(total_events),
                        "events_captured": int(events_captured),
                        "capture_rate": float(capture_rate),
                        "lift_events_per_person": float(lift),
                        "person_score_def": "max_k_CIF_k(tau)",
                    }
                )

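            # Micro-example (a sketch): selecting the top 10% of 1000 records
            # (n_sel = 100) that hold 60 of the bin's 200 in-horizon events
            # gives capture_rate = 60/200 = 0.30 and
            # lift = (60/100) / (200/1000) = 3.0.
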
            # -------------------------
            # Diagnostics (optional): approximate event-driven AUC/Brier
            # -------------------------
            # Convert event-level data to binary labels:
            # y[i, k] = 1 iff >= 1 event of cause k within (t0, t0+tau].
            y_tau_bin = np.zeros((n_bin, K), dtype=np.int8)
            if evt_dt_b.size > 0:
                m_evt_tau = evt_dt_b <= float(tau)
                if np.any(m_evt_tau):
                    rec_local = local_map[evt_rec_idx_b[m_evt_tau]]
                    valid = rec_local >= 0
                    y_tau_bin[rec_local[valid], evt_cause_b[m_evt_tau][valid]] = 1

            n_pos = y_tau_bin.sum(axis=0).astype(np.int64)
            n_neg = (int(n_bin) - n_pos).astype(np.int64)

            brier_per_cause = np.mean(
                (y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2,
                axis=0,
            )
            brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
            brier_weighted = (float(np.sum(brier_per_cause * n_pos) / np.sum(n_pos))
                              if np.sum(n_pos) > 0 else float("nan"))

            # Only causes with at least min_pos positives and one negative get
            # a diagnostic AUC; the rest stay NaN.
            auc = np.full((K,), np.nan, dtype=np.float64)
            min_pos = int(args.min_pos)
            candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
            for k in candidates:
                auc[k] = roc_auc_ovr(
                    y_tau_bin[:, k].astype(np.int32), s_tau[:, k].astype(np.float64))

            finite_auc = auc[np.isfinite(auc)]
            auc_macro = float(np.mean(finite_auc)) if finite_auc.size > 0 else float("nan")
            w_mask = np.isfinite(auc)
            auc_weighted = (float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(n_pos[w_mask]))
                            if np.sum(n_pos[w_mask]) > 0 else float("nan"))
            n_valid_auc = int(np.isfinite(auc).sum())

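            # Note: auc_macro is the unweighted mean over causes that cleared
            # the min_pos threshold, while auc_weighted_by_npos weights those
            # same AUCs by positive counts, so rare causes influence it less.
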
            diag_rows.append(
                {
                    "age_bin": age_label,
                    "tau_years": float(tau),
                    "n_records": int(n_bin),
                    "n_causes": int(K),
                    "auc_macro": auc_macro,
                    "auc_weighted_by_npos": auc_weighted,
                    "n_causes_valid_auc": int(n_valid_auc),
                    "brier_macro": brier_macro,
                    "brier_weighted_by_npos": brier_weighted,
                }
            )

            diag_per_cause_parts.append(
                pd.DataFrame(
                    {
                        "age_bin": age_label,
                        "tau_years": float(tau),
                        "cause_id": np.arange(K, dtype=np.int64),
                        "n_pos": n_pos,
                        "n_neg": n_neg,
                        "auc": auc,
                        "brier": brier_per_cause,
                    }
                )
            )

    out_capture = os.path.join(run_dir, "horizon_capture.csv")
    out_wy = os.path.join(run_dir, "workload_yield.csv")
    out_diag = os.path.join(run_dir, "horizon_metrics.csv")
    out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv")

    pd.DataFrame(capture_rows).to_csv(out_capture, index=False)
    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
    pd.DataFrame(diag_rows).to_csv(out_diag, index=False)
    if diag_per_cause_parts:
        pd.concat(diag_per_cause_parts, ignore_index=True).to_csv(
            out_diag_pc, index=False)
    else:
        # No bin produced diagnostics; still write an empty file with headers.
        pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos",
                              "n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False)

    print(f"Wrote {out_capture}")
    print(f"Wrote {out_wy}")
    print(f"Wrote {out_diag} (diagnostic-only)")
    print(f"Wrote {out_diag_pc} (diagnostic-only)")


if __name__ == "__main__":
    main()