"""Horizon-capture evaluation. DISCLAIMERS (important): - The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$. Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW. Use it for model comparison and diagnostics, not strict statistical interpretation. - The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment). Use them to detect probability-mass compression / numerical stability issues; do not claim calibrated absolute risk. """ import argparse import os from typing import Dict, List, Sequence import numpy as np import pandas as pd import torch from torch.utils.data import DataLoader from utils import ( EvalRecordDataset, build_dataset_from_config, build_event_driven_records, build_model_head_criterion, eval_collate_fn, flatten_future_events, get_test_subset, load_checkpoint_into, load_train_config, make_inference_dataloader_kwargs, parse_float_list, predict_cifs, roc_auc_ovr, seed_everything, topk_indices, ) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="Evaluate horizon-capture using CIF at horizons") p.add_argument("--run_dir", type=str, required=True) p.add_argument( "--horizons", type=str, nargs="+", default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"], help="Horizon grid in years", ) p.add_argument( "--age_bins", type=str, nargs="+", default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"], help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)", ) p.add_argument( "--device", type=str, default=("cuda" if torch.cuda.is_available() else "cpu"), ) p.add_argument("--batch_size", type=int, default=256) p.add_argument("--num_workers", type=int, default=0) p.add_argument("--seed", type=int, default=0) p.add_argument("--min_pos", type=int, default=20) p.add_argument( "--topk_list", type=int, nargs="+", default=[1, 5, 10, 20, 50], ) return p.parse_args() def build_labels_within_tau_flat( n_records: int, n_causes: int, event_record_idx: np.ndarray, event_cause: np.ndarray, event_dt_years: np.ndarray, tau_years: float, ) -> np.ndarray: """Build y_within_tau using a flattened (record,cause,dt) representation. This preserves the exact label definition: y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau]. 
""" y = np.zeros((n_records, n_causes), dtype=np.int8) if event_dt_years.size == 0: return y m = event_dt_years <= float(tau_years) if not np.any(m): return y y[event_record_idx[m], event_cause[m]] = 1 return y def main() -> None: args = parse_args() seed_everything(args.seed) run_dir = args.run_dir cfg = load_train_config(run_dir) dataset = build_dataset_from_config(cfg) test_subset = get_test_subset(dataset, cfg) age_bins_years = parse_float_list(args.age_bins) horizons = parse_float_list(args.horizons) horizons = [float(h) for h in horizons] records = build_event_driven_records( dataset=dataset, subset=test_subset, age_bins_years=age_bins_years, seed=args.seed, ) device = torch.device(args.device) model, head, criterion = build_model_head_criterion(cfg, dataset, device) load_checkpoint_into(run_dir, model, head, criterion, device) rec_ds = EvalRecordDataset(dataset, records) dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers) loader = DataLoader( rec_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=eval_collate_fn, **dl_kwargs, ) # Print disclaimers every run (requested) print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).") print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).") scores = predict_cifs(model, head, criterion, loader, horizons, device=device) # scores shape: (N, K, H) if scores.ndim != 3: raise ValueError( f"Expected CIF scores with shape (N,K,H), got {scores.shape}") N, K, H = scores.shape if N != len(records): raise ValueError("Record count mismatch") # Pre-flatten all future events once to avoid repeated per-record scans. evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K) per_tau_rows: List[Dict[str, object]] = [] per_cause_rows: List[Dict[str, object]] = [] workload_rows: List[Dict[str, object]] = [] for h_idx, tau in enumerate(horizons): s_tau = scores[:, :, h_idx] y_tau = build_labels_within_tau_flat( N, K, evt_rec_idx, evt_cause, evt_dt, tau) # Per-cause counts + Brier (vectorized) n_pos = y_tau.sum(axis=0).astype(np.int64) n_neg = (int(N) - n_pos).astype(np.int64) # Brier per cause: mean_i (y - s)^2 brier_per_cause = np.mean( (y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0) brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan") brier_weighted = float(np.sum( brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan") # AUC: compute only for causes with enough positives and at least 1 negative auc = np.full((K,), np.nan, dtype=np.float64) min_pos = int(args.min_pos) candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0)) for k in candidates: auc[k] = roc_auc_ovr(y_tau[:, k].astype( np.int32), s_tau[:, k].astype(np.float64)) finite_auc = auc[np.isfinite(auc)] auc_macro = float(np.mean(finite_auc) ) if finite_auc.size > 0 else float("nan") w_mask = np.isfinite(auc) auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum( n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan") n_valid_auc = int(np.isfinite(auc).sum()) # Append per-cause rows (vectorized via DataFrame to avoid Python loops) per_cause_rows.append( pd.DataFrame( { "tau_years": float(tau), "cause_id": np.arange(K, dtype=np.int64), "n_pos": n_pos, "n_neg": n_neg, "auc": auc, "brier": brier_per_cause, } ) ) # Business metrics for each topK denom_true_pairs = int(y_tau.sum()) for topk in args.topk_list: topk = int(topk) idx = topk_indices(s_tau, topk) captured = 

def main() -> None:
    args = parse_args()
    seed_everything(args.seed)

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    horizons = parse_float_list(args.horizons)
    horizons = [float(h) for h in horizons]

    records = build_event_driven_records(
        dataset=dataset,
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(dataset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    # Print disclaimers on every run.
    print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
    print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")

    scores = predict_cifs(model, head, criterion, loader, horizons, device=device)
    # scores shape: (N, K, H)
    if scores.ndim != 3:
        raise ValueError(
            f"Expected CIF scores with shape (N,K,H), got {scores.shape}")
    N, K, H = scores.shape
    if N != len(records):
        raise ValueError("Record count mismatch")

    # Pre-flatten all future events once to avoid repeated per-record scans.
    evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)

    per_tau_rows: List[Dict[str, object]] = []
    per_cause_rows: List[Dict[str, object]] = []
    workload_rows: List[Dict[str, object]] = []

    for h_idx, tau in enumerate(horizons):
        s_tau = scores[:, :, h_idx]
        y_tau = build_labels_within_tau_flat(
            N, K, evt_rec_idx, evt_cause, evt_dt, tau)

        # Per-cause counts + Brier (vectorized)
        n_pos = y_tau.sum(axis=0).astype(np.int64)
        n_neg = (int(N) - n_pos).astype(np.int64)
        # Brier per cause: mean_i (y - s)^2
        brier_per_cause = np.mean(
            (y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
        brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
        brier_weighted = (
            float(np.sum(brier_per_cause * n_pos) / np.sum(n_pos))
            if np.sum(n_pos) > 0 else float("nan")
        )

        # AUC: compute only for causes with enough positives and at least 1 negative
        auc = np.full((K,), np.nan, dtype=np.float64)
        min_pos = int(args.min_pos)
        candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
        for k in candidates:
            auc[k] = roc_auc_ovr(
                y_tau[:, k].astype(np.int32), s_tau[:, k].astype(np.float64))
        finite_auc = auc[np.isfinite(auc)]
        auc_macro = float(np.mean(finite_auc)) if finite_auc.size > 0 else float("nan")
        w_mask = np.isfinite(auc)
        auc_weighted = (
            float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(n_pos[w_mask]))
            if np.sum(n_pos[w_mask]) > 0 else float("nan")
        )
        n_valid_auc = int(np.isfinite(auc).sum())

        # Append per-cause rows (vectorized via DataFrame to avoid Python loops)
        per_cause_rows.append(
            pd.DataFrame(
                {
                    "tau_years": float(tau),
                    "cause_id": np.arange(K, dtype=np.int64),
                    "n_pos": n_pos,
                    "n_neg": n_neg,
                    "auc": auc,
                    "brier": brier_per_cause,
                }
            )
        )

        # Business metrics for each topK
        denom_true_pairs = int(y_tau.sum())
        for topk in args.topk_list:
            topk = int(topk)
            idx = topk_indices(s_tau, topk)
            captured = np.take_along_axis(y_tau, idx, axis=1)
            hits = captured.sum(axis=1).astype(np.float64)
            true_cnt = y_tau.sum(axis=1).astype(np.float64)

            precision_like = hits / float(min(topk, K))
            mean_precision = float(np.mean(precision_like)) if N > 0 else float("nan")

            mask_has_true = true_cnt > 0
            recall_like = np.full((N,), np.nan, dtype=np.float64)
            recall_like[mask_has_true] = hits[mask_has_true] / true_cnt[mask_has_true]
            mean_recall = (
                float(np.nanmean(recall_like)) if np.any(mask_has_true) else float("nan")
            )
            median_recall = (
                float(np.nanmedian(recall_like)) if np.any(mask_has_true) else float("nan")
            )

            numer_captured_pairs = int(captured.sum())
            pop_capture_rate = (
                float(numer_captured_pairs / denom_true_pairs)
                if denom_true_pairs > 0 else float("nan")
            )

            workload_rows.append(
                {
                    "tau_years": float(tau),
                    "topk": int(topk),
                    "population_capture_rate": pop_capture_rate,
                    "mean_precision_like": mean_precision,
                    "mean_recall_like": mean_recall,
                    "median_recall_like": median_recall,
                    "denom_true_pairs": denom_true_pairs,
                    "numer_captured_pairs": numer_captured_pairs,
                }
            )

        per_tau_rows.append(
            {
                "tau_years": float(tau),
                "n_records": int(N),
                "n_causes": int(K),
                "auc_macro": auc_macro,
                "auc_weighted_by_npos": auc_weighted,
                "n_causes_valid_auc": int(n_valid_auc),
                "brier_macro": brier_macro,
                "brier_weighted_by_npos": brier_weighted,
                "total_true_pairs": denom_true_pairs,
            }
        )

    out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
    out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
    out_wy = os.path.join(run_dir, "workload_yield.csv")

    pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
    if per_cause_rows:
        pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False)
    else:
        pd.DataFrame(columns=["tau_years", "cause_id", "n_pos", "n_neg",
                              "auc", "brier"]).to_csv(out_pc, index=False)
    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)

    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")
    print(f"Wrote {out_wy}")


if __name__ == "__main__":
    main()
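
# Example invocation (the script filename and run-directory path below are
# placeholders, not names defined by this repository):
#   python horizon_capture_eval.py --run_dir runs/example \
#       --horizons 0.5 1.0 5.0 --topk_list 5 10 20 --min_pos 30
# The run writes horizon_metrics.csv, horizon_per_cause.csv, and
# workload_yield.csv into --run_dir.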