Files
DeepHealth/evaluate_horizon.py

278 lines
9.5 KiB
Python

"""Horizon-capture evaluation.
DISCLAIMERS (important):
- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$.
Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW.
Use it for model comparison and diagnostics, not strict statistical interpretation.
- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment).
Use them to detect probability-mass compression / numerical stability issues; do not claim
calibrated absolute risk.
"""
import argparse
import os
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    flatten_future_events,
    get_test_subset,
    load_checkpoint_into,
    load_train_config,
    make_inference_dataloader_kwargs,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
)
def parse_args() -> argparse.Namespace:
    """Build the command-line interface for horizon-capture evaluation and parse it.

    ``--horizons`` and ``--age_bins`` are collected as raw strings (one or more
    values each); they are converted to floats later by the caller.
    ``--device`` defaults to ``"cuda"`` when ``torch.cuda.is_available()`` is
    true, otherwise ``"cpu"``.

    Returns:
        The parsed namespace with attributes run_dir, horizons, age_bins,
        device, batch_size, num_workers, seed, min_pos and topk_list.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate horizon-capture using CIF at horizons"
    )
    parser.add_argument("--run_dir", type=str, required=True)

    default_horizons = ["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"]
    parser.add_argument(
        "--horizons",
        type=str,
        nargs="+",
        default=default_horizons,
        help="Horizon grid in years",
    )

    default_age_bins = ["40", "45", "50", "55", "60", "65", "70", "75", "inf"]
    parser.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=default_age_bins,
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
    )

    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", type=str, default=default_device)

    # Plain integer knobs, registered in the same order as they appear in --help.
    integer_options = (
        ("--batch_size", 256),
        ("--num_workers", 0),
        ("--seed", 0),
        ("--min_pos", 20),
    )
    for flag, default_value in integer_options:
        parser.add_argument(flag, type=int, default=default_value)

    parser.add_argument(
        "--topk_list",
        type=int,
        nargs="+",
        default=[1, 5, 10, 20, 50],
    )
    return parser.parse_args()
def build_labels_within_tau_flat(
    n_records: int,
    n_causes: int,
    event_record_idx: np.ndarray,
    event_cause: np.ndarray,
    event_dt_years: np.ndarray,
    tau_years: float,
) -> np.ndarray:
    """Build the binary ``y_within_tau`` label matrix from flattened events.

    The three event arrays are parallel: entry ``j`` says that record
    ``event_record_idx[j]`` has an event of cause ``event_cause[j]`` occurring
    ``event_dt_years[j]`` years after that record's reference time.

    ``labels[i, k] == 1`` iff at least one event of cause ``k`` for record ``i``
    has ``dt <= tau_years``; repeated (record, cause) pairs simply set the
    same cell again. Only the upper bound is applied here. NOTE(review): the
    (t0, t0+tau] window therefore assumes the caller passes strictly-future
    events (dt > 0); confirm against ``flatten_future_events``.

    Returns:
        int8 array of shape (n_records, n_causes).
    """
    labels = np.zeros((n_records, n_causes), dtype=np.int8)
    if event_dt_years.size > 0:
        within_horizon = event_dt_years <= float(tau_years)
        # Skip the fancy-index assignment entirely when nothing qualifies.
        if within_horizon.any():
            rows = event_record_idx[within_horizon]
            cols = event_cause[within_horizon]
            labels[rows, cols] = 1
    return labels
def _per_cause_brier(
    y_tau: np.ndarray, s_tau: np.ndarray, n_pos: np.ndarray
) -> Tuple[np.ndarray, float, float]:
    """Compute unadjusted (no censoring correction) Brier scores for one horizon.

    Args:
        y_tau: (N, K) 0/1 labels for this horizon.
        s_tau: (N, K) predicted scores for this horizon.
        n_pos: (K,) positive count per cause, used as weights.

    Returns:
        ``(brier_per_cause, brier_macro, brier_weighted)``.
        ``brier_per_cause[k]`` is ``mean_i (y[i, k] - s[i, k]) ** 2``.
        ``brier_macro`` is the unweighted mean over causes (NaN when K == 0).
        ``brier_weighted`` is the ``n_pos``-weighted mean (NaN when there are
        no positives at all).
    """
    n_causes = y_tau.shape[1]
    brier_per_cause = np.mean(
        (y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
    brier_macro = (
        float(np.mean(brier_per_cause)) if n_causes > 0 else float("nan"))
    total_pos = np.sum(n_pos)
    brier_weighted = (
        float(np.sum(brier_per_cause * n_pos) / total_pos)
        if total_pos > 0 else float("nan"))
    return brier_per_cause, brier_macro, brier_weighted


def _per_cause_auc(
    y_tau: np.ndarray,
    s_tau: np.ndarray,
    n_pos: np.ndarray,
    n_neg: np.ndarray,
    min_pos: int,
) -> Tuple[np.ndarray, float, float, int]:
    """Compute per-cause AUC for one horizon plus macro/weighted aggregates.

    AUC is computed only for causes with at least ``min_pos`` positives and at
    least one negative. All other causes, and any non-finite AUC returned by
    ``roc_auc_ovr``, are left as NaN and excluded from the aggregates.

    Returns:
        ``(auc, auc_macro, auc_weighted, n_valid)``.
        ``auc`` is a (K,) float64 array.
        ``auc_weighted`` weights each valid cause by its ``n_pos``.
        ``n_valid`` counts the finite entries of ``auc``.
    """
    n_causes = y_tau.shape[1]
    auc = np.full((n_causes,), np.nan, dtype=np.float64)
    candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
    for k in candidates:
        auc[k] = roc_auc_ovr(
            y_tau[:, k].astype(np.int32), s_tau[:, k].astype(np.float64))

    valid = np.isfinite(auc)
    auc_macro = float(np.mean(auc[valid])) if np.any(valid) else float("nan")
    pos_valid = n_pos[valid]
    auc_weighted = (
        float(np.sum(auc[valid] * pos_valid) / np.sum(pos_valid))
        if np.sum(pos_valid) > 0 else float("nan"))
    return auc, auc_macro, auc_weighted, int(valid.sum())


def _topk_workload_rows(
    tau: float,
    y_tau: np.ndarray,
    s_tau: np.ndarray,
    topk_list: Sequence[int],
) -> List[Dict[str, object]]:
    """Build one workload/yield row per K in ``topk_list`` for one horizon.

    For each record, the top-K scored causes (via ``topk_indices``) are
    compared against its true within-horizon causes.

    Returns:
        A list of row dicts, one per K, in ``topk_list`` order.
    """
    n_records, n_causes = y_tau.shape
    denom_true_pairs = int(y_tau.sum())
    # Per-record true counts do not depend on K; compute them once per horizon.
    true_cnt = y_tau.sum(axis=1).astype(np.float64)
    mask_has_true = true_cnt > 0
    any_has_true = bool(np.any(mask_has_true))

    rows: List[Dict[str, object]] = []
    for topk in topk_list:
        topk = int(topk)
        idx = topk_indices(s_tau, topk)
        captured = np.take_along_axis(y_tau, idx, axis=1)
        hits = captured.sum(axis=1).astype(np.float64)

        # Denominator is capped at K so a K larger than the cause count does
        # not deflate precision.
        precision_like = hits / float(min(topk, n_causes))
        mean_precision = (
            float(np.mean(precision_like)) if n_records > 0 else float("nan"))

        # Recall is only defined for records with at least one true cause.
        recall_like = np.full((n_records,), np.nan, dtype=np.float64)
        recall_like[mask_has_true] = (
            hits[mask_has_true] / true_cnt[mask_has_true])
        if any_has_true:
            mean_recall = float(np.nanmean(recall_like))
            median_recall = float(np.nanmedian(recall_like))
        else:
            mean_recall = float("nan")
            median_recall = float("nan")

        numer_captured_pairs = int(captured.sum())
        pop_capture_rate = (
            float(numer_captured_pairs / denom_true_pairs)
            if denom_true_pairs > 0 else float("nan"))
        rows.append(
            {
                "tau_years": float(tau),
                "topk": int(topk),
                "population_capture_rate": pop_capture_rate,
                "mean_precision_like": mean_precision,
                "mean_recall_like": mean_recall,
                "median_recall_like": median_recall,
                "denom_true_pairs": denom_true_pairs,
                "numer_captured_pairs": numer_captured_pairs,
            }
        )
    return rows


def _write_outputs(
    run_dir: str,
    per_tau_rows: List[Dict[str, object]],
    per_cause_frames: List[pd.DataFrame],
    workload_rows: List[Dict[str, object]],
) -> None:
    """Write the three result CSVs into ``run_dir`` and print their paths.

    The files are horizon_metrics.csv, horizon_per_cause.csv and
    workload_yield.csv. When there are no per-cause frames (no horizons),
    horizon_per_cause.csv is written with its header only.
    """
    out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
    out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
    out_wy = os.path.join(run_dir, "workload_yield.csv")
    pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
    if per_cause_frames:
        pd.concat(per_cause_frames, ignore_index=True).to_csv(
            out_pc, index=False)
    else:
        pd.DataFrame(columns=["tau_years", "cause_id", "n_pos",
                              "n_neg", "auc", "brier"]).to_csv(out_pc, index=False)
    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")
    print(f"Wrote {out_wy}")


def main() -> None:
    """Run horizon-capture evaluation for the run in ``--run_dir``.

    Steps:
    1. Load the training config from the run directory.
    2. Rebuild the dataset and its test subset.
    3. Build event-driven evaluation records.
    4. Load the model checkpoint.
    5. Predict CIF scores of shape (N, K, H) at every requested horizon.

    Then, for each horizon, compute:
    - per-cause AUC and Brier scores, with their aggregates;
    - top-K workload/yield metrics.

    Results are written to three CSVs in the run directory.

    Raises:
        ValueError: if the predicted scores are not 3-D, or if their record
            count differs from the number of evaluation records.
    """
    args = parse_args()
    seed_everything(args.seed)
    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)
    age_bins_years = parse_float_list(args.age_bins)
    horizons = [float(h) for h in parse_float_list(args.horizons)]
    records = build_event_driven_records(
        dataset=dataset,
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)
    rec_ds = EvalRecordDataset(dataset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    # Print disclaimers every run (requested)
    print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
    print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")

    scores = predict_cifs(model, head, criterion, loader,
                          horizons, device=device)
    # scores shape: (N, K, H)
    if scores.ndim != 3:
        raise ValueError(
            f"Expected CIF scores with shape (N,K,H), got {scores.shape}")
    n_records, n_causes, _ = scores.shape
    if n_records != len(records):
        raise ValueError("Record count mismatch")

    # Pre-flatten all future events once to avoid repeated per-record scans.
    evt_rec_idx, evt_cause, evt_dt = flatten_future_events(
        records, n_causes=n_causes)
    min_pos = int(args.min_pos)

    per_tau_rows: List[Dict[str, object]] = []
    per_cause_frames: List[pd.DataFrame] = []
    workload_rows: List[Dict[str, object]] = []
    for h_idx, tau in enumerate(horizons):
        s_tau = scores[:, :, h_idx]
        y_tau = build_labels_within_tau_flat(
            n_records, n_causes, evt_rec_idx, evt_cause, evt_dt, tau)
        n_pos = y_tau.sum(axis=0).astype(np.int64)
        n_neg = (int(n_records) - n_pos).astype(np.int64)

        brier_per_cause, brier_macro, brier_weighted = _per_cause_brier(
            y_tau, s_tau, n_pos)
        auc, auc_macro, auc_weighted, n_valid_auc = _per_cause_auc(
            y_tau, s_tau, n_pos, n_neg, min_pos)

        per_cause_frames.append(
            pd.DataFrame(
                {
                    "tau_years": float(tau),
                    "cause_id": np.arange(n_causes, dtype=np.int64),
                    "n_pos": n_pos,
                    "n_neg": n_neg,
                    "auc": auc,
                    "brier": brier_per_cause,
                }
            )
        )
        workload_rows.extend(
            _topk_workload_rows(tau, y_tau, s_tau, args.topk_list))
        per_tau_rows.append(
            {
                "tau_years": float(tau),
                "n_records": int(n_records),
                "n_causes": int(n_causes),
                "auc_macro": auc_macro,
                "auc_weighted_by_npos": auc_weighted,
                "n_causes_valid_auc": int(n_valid_auc),
                "brier_macro": brier_macro,
                "brier_weighted_by_npos": brier_weighted,
                # Sum of per-cause positives == total true (record, cause) pairs.
                "total_true_pairs": int(n_pos.sum()),
            }
        )

    _write_outputs(run_dir, per_tau_rows, per_cause_frames, workload_rows)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()