Add evaluation scripts for next-event prediction and horizon-capture evaluation with detailed metric disclaimers
This commit is contained in:
277
evaluate_horizon.py
Normal file
277
evaluate_horizon.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Horizon-capture evaluation.
|
||||
|
||||
DISCLAIMERS (important):
|
||||
- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$.
|
||||
Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW.
|
||||
Use it for model comparison and diagnostics, not strict statistical interpretation.
|
||||
|
||||
- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment).
|
||||
Use them to detect probability-mass compression / numerical stability issues; do not claim
|
||||
calibrated absolute risk.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from utils import (
|
||||
EvalRecordDataset,
|
||||
build_dataset_from_config,
|
||||
build_event_driven_records,
|
||||
build_model_head_criterion,
|
||||
eval_collate_fn,
|
||||
flatten_future_events,
|
||||
get_test_subset,
|
||||
load_checkpoint_into,
|
||||
load_train_config,
|
||||
make_inference_dataloader_kwargs,
|
||||
parse_float_list,
|
||||
predict_cifs,
|
||||
roc_auc_ovr,
|
||||
seed_everything,
|
||||
topk_indices,
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(
|
||||
description="Evaluate horizon-capture using CIF at horizons")
|
||||
p.add_argument("--run_dir", type=str, required=True)
|
||||
p.add_argument(
|
||||
"--horizons",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
|
||||
help="Horizon grid in years",
|
||||
)
|
||||
p.add_argument(
|
||||
"--age_bins",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
|
||||
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default=("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=256)
|
||||
p.add_argument("--num_workers", type=int, default=0)
|
||||
p.add_argument("--seed", type=int, default=0)
|
||||
p.add_argument("--min_pos", type=int, default=20)
|
||||
p.add_argument(
|
||||
"--topk_list",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 5, 10, 20, 50],
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def build_labels_within_tau_flat(
|
||||
n_records: int,
|
||||
n_causes: int,
|
||||
event_record_idx: np.ndarray,
|
||||
event_cause: np.ndarray,
|
||||
event_dt_years: np.ndarray,
|
||||
tau_years: float,
|
||||
) -> np.ndarray:
|
||||
"""Build y_within_tau using a flattened (record,cause,dt) representation.
|
||||
|
||||
This preserves the exact label definition: y[i,k]=1 iff at least one event of cause k
|
||||
occurs in (t0, t0+tau].
|
||||
"""
|
||||
y = np.zeros((n_records, n_causes), dtype=np.int8)
|
||||
if event_dt_years.size == 0:
|
||||
return y
|
||||
m = event_dt_years <= float(tau_years)
|
||||
if not np.any(m):
|
||||
return y
|
||||
y[event_record_idx[m], event_cause[m]] = 1
|
||||
return y
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
seed_everything(args.seed)
|
||||
|
||||
run_dir = args.run_dir
|
||||
cfg = load_train_config(run_dir)
|
||||
|
||||
dataset = build_dataset_from_config(cfg)
|
||||
test_subset = get_test_subset(dataset, cfg)
|
||||
|
||||
age_bins_years = parse_float_list(args.age_bins)
|
||||
horizons = parse_float_list(args.horizons)
|
||||
horizons = [float(h) for h in horizons]
|
||||
|
||||
records = build_event_driven_records(
|
||||
dataset=dataset,
|
||||
subset=test_subset,
|
||||
age_bins_years=age_bins_years,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
device = torch.device(args.device)
|
||||
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
|
||||
load_checkpoint_into(run_dir, model, head, criterion, device)
|
||||
|
||||
rec_ds = EvalRecordDataset(dataset, records)
|
||||
|
||||
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
|
||||
loader = DataLoader(
|
||||
rec_ds,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
collate_fn=eval_collate_fn,
|
||||
**dl_kwargs,
|
||||
)
|
||||
|
||||
# Print disclaimers every run (requested)
|
||||
print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
|
||||
print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
|
||||
|
||||
scores = predict_cifs(model, head, criterion, loader,
|
||||
horizons, device=device)
|
||||
# scores shape: (N, K, H)
|
||||
if scores.ndim != 3:
|
||||
raise ValueError(
|
||||
f"Expected CIF scores with shape (N,K,H), got {scores.shape}")
|
||||
|
||||
N, K, H = scores.shape
|
||||
if N != len(records):
|
||||
raise ValueError("Record count mismatch")
|
||||
|
||||
# Pre-flatten all future events once to avoid repeated per-record scans.
|
||||
evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)
|
||||
|
||||
per_tau_rows: List[Dict[str, object]] = []
|
||||
per_cause_rows: List[Dict[str, object]] = []
|
||||
workload_rows: List[Dict[str, object]] = []
|
||||
|
||||
for h_idx, tau in enumerate(horizons):
|
||||
s_tau = scores[:, :, h_idx]
|
||||
y_tau = build_labels_within_tau_flat(
|
||||
N, K, evt_rec_idx, evt_cause, evt_dt, tau)
|
||||
|
||||
# Per-cause counts + Brier (vectorized)
|
||||
n_pos = y_tau.sum(axis=0).astype(np.int64)
|
||||
n_neg = (int(N) - n_pos).astype(np.int64)
|
||||
|
||||
# Brier per cause: mean_i (y - s)^2
|
||||
brier_per_cause = np.mean(
|
||||
(y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
|
||||
brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
|
||||
brier_weighted = float(np.sum(
|
||||
brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
|
||||
|
||||
# AUC: compute only for causes with enough positives and at least 1 negative
|
||||
auc = np.full((K,), np.nan, dtype=np.float64)
|
||||
min_pos = int(args.min_pos)
|
||||
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
|
||||
for k in candidates:
|
||||
auc[k] = roc_auc_ovr(y_tau[:, k].astype(
|
||||
np.int32), s_tau[:, k].astype(np.float64))
|
||||
|
||||
finite_auc = auc[np.isfinite(auc)]
|
||||
auc_macro = float(np.mean(finite_auc)
|
||||
) if finite_auc.size > 0 else float("nan")
|
||||
w_mask = np.isfinite(auc)
|
||||
auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
|
||||
n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
|
||||
n_valid_auc = int(np.isfinite(auc).sum())
|
||||
|
||||
# Append per-cause rows (vectorized via DataFrame to avoid Python loops)
|
||||
per_cause_rows.append(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"tau_years": float(tau),
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": n_pos,
|
||||
"n_neg": n_neg,
|
||||
"auc": auc,
|
||||
"brier": brier_per_cause,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Business metrics for each topK
|
||||
denom_true_pairs = int(y_tau.sum())
|
||||
for topk in args.topk_list:
|
||||
topk = int(topk)
|
||||
idx = topk_indices(s_tau, topk)
|
||||
captured = np.take_along_axis(y_tau, idx, axis=1)
|
||||
hits = captured.sum(axis=1).astype(np.float64)
|
||||
true_cnt = y_tau.sum(axis=1).astype(np.float64)
|
||||
|
||||
precision_like = hits / float(min(topk, K))
|
||||
mean_precision = float(np.mean(precision_like)
|
||||
) if N > 0 else float("nan")
|
||||
|
||||
mask_has_true = true_cnt > 0
|
||||
recall_like = np.full((N,), np.nan, dtype=np.float64)
|
||||
recall_like[mask_has_true] = hits[mask_has_true] / \
|
||||
true_cnt[mask_has_true]
|
||||
mean_recall = float(np.nanmean(recall_like)) if np.any(
|
||||
mask_has_true) else float("nan")
|
||||
median_recall = float(np.nanmedian(recall_like)) if np.any(
|
||||
mask_has_true) else float("nan")
|
||||
|
||||
numer_captured_pairs = int(captured.sum())
|
||||
pop_capture_rate = float(
|
||||
numer_captured_pairs / denom_true_pairs) if denom_true_pairs > 0 else float("nan")
|
||||
|
||||
workload_rows.append(
|
||||
{
|
||||
"tau_years": float(tau),
|
||||
"topk": int(topk),
|
||||
"population_capture_rate": pop_capture_rate,
|
||||
"mean_precision_like": mean_precision,
|
||||
"mean_recall_like": mean_recall,
|
||||
"median_recall_like": median_recall,
|
||||
"denom_true_pairs": denom_true_pairs,
|
||||
"numer_captured_pairs": numer_captured_pairs,
|
||||
}
|
||||
)
|
||||
|
||||
per_tau_rows.append(
|
||||
{
|
||||
"tau_years": float(tau),
|
||||
"n_records": int(N),
|
||||
"n_causes": int(K),
|
||||
"auc_macro": auc_macro,
|
||||
"auc_weighted_by_npos": auc_weighted,
|
||||
"n_causes_valid_auc": int(n_valid_auc),
|
||||
"brier_macro": brier_macro,
|
||||
"brier_weighted_by_npos": brier_weighted,
|
||||
"total_true_pairs": denom_true_pairs,
|
||||
}
|
||||
)
|
||||
|
||||
out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
|
||||
out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
|
||||
out_wy = os.path.join(run_dir, "workload_yield.csv")
|
||||
|
||||
pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
|
||||
if per_cause_rows:
|
||||
pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False)
|
||||
else:
|
||||
pd.DataFrame(columns=["tau_years", "cause_id", "n_pos",
|
||||
"n_neg", "auc", "brier"]).to_csv(out_pc, index=False)
|
||||
pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
|
||||
|
||||
print(f"Wrote {out_metrics}")
|
||||
print(f"Wrote {out_pc}")
|
||||
print(f"Wrote {out_wy}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user