import argparse
import os
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm  # noqa: F401
except Exception:  # pragma: no cover
    tqdm = None

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using short-window CIF"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument("--tau_short", type=float, required=True, help="years")
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    show_progress = not args.no_tqdm

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        dataset=dataset,
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(dataset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    tau = float(args.tau_short)
    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        [tau],
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-event)",
    )
    # scores shape: (N, K, 1) for multi-taus; squeeze last
    if scores.ndim == 3:
        scores = scores[:, :, 0]

    n_records_total = len(records)
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause)) for r in records],
        dtype=np.int64,
    )
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = float(n_eligible / n_records_total) if n_records_total > 0 else 0.0

    metrics_rows: List[dict] = []
    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
    metrics_rows.append({"metric": "n_next_event_eligible", "value": n_eligible})
    metrics_rows.append({"metric": "coverage", "value": coverage})
    metrics_rows.append({"metric": "tau_short_years", "value": tau})

    if n_eligible == 0:
        out_path = os.path.join(run_dir, "next_event_metrics.csv")
        pd.DataFrame(metrics_rows).to_csv(out_path, index=False)
        print(f"No eligible records; wrote {out_path}")
        return

    scores_e = scores[eligible]
    y_e = y_next[eligible]

    # Top-1 accuracy
    pred = scores_e.argmax(axis=1)
    acc = float((pred == y_e).mean())
    metrics_rows.append({"metric": "top1_accuracy", "value": acc})

    # MRR
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = np.empty(y_e.shape[0], dtype=np.int32)
    for i in range(y_e.shape[0]):
        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
    mrr = float((1.0 / ranks).mean())
    metrics_rows.append({"metric": "mrr", "value": mrr})

    # HitRate@K
    for k in [1, 3, 5, 10, 20]:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float(hit.mean())})

    # Macro OvR AUC per cause (optional)
    K = scores.shape[1]
    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
    auc = np.full((K,), np.nan, dtype=np.float64)
    min_pos = int(args.min_pos)
    candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
    for k in candidates:
        auc_k = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
        auc[k] = auc_k
    included = (n_pos >= min_pos) & (n_neg > 0)
    per_cause_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": n_pos,
            "n_neg": n_neg,
            "auc": auc,
            "included": included,
        }
    )
    aucs = auc[np.isfinite(auc)]
    if aucs.size > 0:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
    else:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})

    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)

    # optional per-cause
    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
    per_cause_df.to_csv(out_pc, index=False)

    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")


if __name__ == "__main__":
    main()