import argparse
import os
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm  # noqa: F401
except Exception:  # pragma: no cover
    tqdm = None

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using short-window CIF"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument("--tau_short", type=float, required=True, help="years")
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def _format_age_bin_label(lo: float, hi: float) -> str:
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"


def _compute_next_event_metrics(
    *,
    scores: np.ndarray,
    y_next: np.ndarray,
    tau_short: float,
    min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
    """Compute next-event metrics on a given subset.

    Definitions are unchanged from the original script.
    """
    n_records_total = int(y_next.size)
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = float(n_eligible / n_records_total) if n_records_total > 0 else 0.0

    metrics_rows: List[dict] = []
    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
    metrics_rows.append({"metric": "n_next_event_eligible", "value": n_eligible})
    metrics_rows.append({"metric": "coverage", "value": coverage})
    metrics_rows.append({"metric": "tau_short_years", "value": float(tau_short)})

    K = int(scores.shape[1])

    if n_records_total == 0:
        per_cause_df = pd.DataFrame(
            {
                "cause_id": np.arange(K, dtype=np.int64),
                "n_pos": np.zeros((K,), dtype=np.int64),
                "n_neg": np.zeros((K,), dtype=np.int64),
                "auc": np.full((K,), np.nan, dtype=np.float64),
                "included": np.zeros((K,), dtype=bool),
            }
        )
        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
        metrics_rows.append({"metric": "mrr", "value": float("nan")})
        for k in [1, 3, 5, 10, 20]:
            metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float("nan")})
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
        return metrics_rows, per_cause_df

    # If no eligible, keep coverage but leave accuracy-like metrics as NaN.
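    # "Eligible" means the record carries an observed next-event cause
    # (y_next >= 0); main() encodes a missing cause with the -1 sentinel.
    # Ineligible records still contribute to n_records_total and coverage
    # above, but are excluded from the ranking metrics computed below.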
    if n_eligible == 0:
        per_cause_df = pd.DataFrame(
            {
                "cause_id": np.arange(K, dtype=np.int64),
                "n_pos": np.zeros((K,), dtype=np.int64),
                "n_neg": np.full((K,), n_records_total, dtype=np.int64),
                "auc": np.full((K,), np.nan, dtype=np.float64),
                "included": np.zeros((K,), dtype=bool),
            }
        )
        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
        metrics_rows.append({"metric": "mrr", "value": float("nan")})
        for k in [1, 3, 5, 10, 20]:
            metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float("nan")})
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
        return metrics_rows, per_cause_df

    scores_e = scores[eligible]
    y_e = y_next[eligible]

    # Top-1 accuracy
    pred = scores_e.argmax(axis=1)
    acc = float((pred == y_e).mean())
    metrics_rows.append({"metric": "top1_accuracy", "value": acc})

    # MRR: 1-based rank of the true cause in the descending score order.
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = np.empty(y_e.shape[0], dtype=np.int32)
    for i in range(y_e.shape[0]):
        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
    mrr = float((1.0 / ranks).mean())
    metrics_rows.append({"metric": "mrr", "value": mrr})

    # HitRate@K
    for k in [1, 3, 5, 10, 20]:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float(hit.mean())})

    # Macro OvR AUC per cause; only causes with at least min_pos positives
    # and at least one negative are included in the macro average.
    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
    auc = np.full((K,), np.nan, dtype=np.float64)
    candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
    for k in candidates:
        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
    included = (n_pos >= int(min_pos)) & (n_neg > 0)
    per_cause_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": n_pos,
            "n_neg": n_neg,
            "auc": auc,
            "included": included,
        }
    )
    aucs = auc[np.isfinite(auc)]
    if aucs.size > 0:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
    else:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
    return metrics_rows, per_cause_df


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    show_progress = not args.no_tqdm

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    tau = float(args.tau_short)
    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        [tau],
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-event)",
    )
    # predict_cifs returns (N, K, n_taus); a single tau is passed here,
    # so take that slice to get (N, K).
    if scores.ndim == 3:
        scores = scores[:, :, 0]

    # -1 is the sentinel for records without a labeled next-event cause.
    y_next = np.array(
        [
            -1 if r.next_event_cause is None else int(r.next_event_cause)
            for r in records
        ],
        dtype=np.int64,
    )

    # Overall (preserve existing output files/shape)
    metrics_rows, per_cause_df = _compute_next_event_metrics(
        scores=scores,
        y_next=y_next,
        tau_short=tau,
        min_pos=int(args.min_pos),
    )
    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
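    # Both overall CSVs are long-format: next_event_metrics.csv holds one
    # (metric, value) row per entry in metrics_rows; next_event_per_cause.csv
    # holds one row per cause id with its positive/negative counts and AUC.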
    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
    per_cause_df.to_csv(out_pc, index=False)

    # By age bin (new outputs)
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR

    # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1}).
    # side="right" keeps a t0 that falls exactly on a boundary b_i in the
    # half-open bin [b_i, b_{i+1}) rather than the bin below it; records
    # below the first boundary get bin_idx == -1 and match no bin.
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1

    per_bin_metric_rows: List[dict] = []
    per_bin_cause_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)
        m = bin_idx == b
        m_scores = scores[m]
        m_y = y_next[m]
        m_rows, m_pc = _compute_next_event_metrics(
            scores=m_scores,
            y_next=m_y,
            tau_short=tau,
            min_pos=int(args.min_pos),
        )
        for row in m_rows:
            per_bin_metric_rows.append({"age_bin": label, **row})
        m_pc = m_pc.copy()
        m_pc.insert(0, "age_bin", label)
        per_bin_cause_parts.append(m_pc)

    out_metrics_bins = os.path.join(run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)

    out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
    if per_bin_cause_parts:
        pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
            out_pc_bins, index=False
        )
    else:
        pd.DataFrame(
            columns=["age_bin", "cause_id", "n_pos", "n_neg", "auc", "included"]
        ).to_csv(out_pc_bins, index=False)

    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_pc_bins}")


if __name__ == "__main__":
    main()
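
# Example invocations (script name and run_dir path are illustrative):
#   python next_event_eval.py --run_dir runs/exp0 --tau_short 1.0
#   python next_event_eval.py --run_dir runs/exp0 --tau_short 0.5 \
#       --age_bins 40 50 60 70 inf --min_pos 50 --no_tqdm
#
# Toy check of the ranking-metric definitions (numbers invented here):
#   scores = [[0.1, 0.7, 0.2]], y_next = [2]  ->  the true cause ranks 2nd,
#   so top1_accuracy = 0.0, mrr = 1/2, hitrate_at_1 = 0.0, hitrate_at_3 = 1.0.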