import argparse
import os
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm  # noqa: F401
except Exception:  # pragma: no cover
    tqdm = None

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_next_token_logits,
    get_auc_delong_var,
    seed_everything,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using next-token scores"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument(
        "--max_cpu_cores",
        type=int,
        default=-1,
        help="Maximum number of CPU cores to use for parallel data construction.",
    )
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def _format_age_bin_label(lo: float, hi: float) -> str:
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"
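
# The clean-control protocol implemented in
# _compute_next_event_auc_clean_control below can be summarized on toy data.
# This sketch is illustrative only: it is never called by the pipeline, and
# all values in it are made up.
def _clean_control_example() -> None:
    # For a fixed cause k (= 2 here), cases are records whose next event is k;
    # clean controls additionally exclude anyone with k in lifetime history.
    y_next = np.array([2, 0, -1, 2, 1])  # next-event cause per record; -1 = no next event
    had_k = np.array([False, True, False, False, False])  # k in lifetime_causes?
    case_mask = y_next == 2
    control_mask = (y_next != 2) & ~had_k  # record 1 is excluded: it had k before
    assert case_mask.sum() == 2 and control_mask.sum() == 2
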
def _compute_next_event_auc_clean_control(
    *,
    scores: np.ndarray,
    records: list,
) -> pd.DataFrame:
    """Delphi-2M next-event AUC (clean control) per cause.

    Definitions per cause k:
      - Case: next_event_cause == k
      - Control (clean): next_event_cause != k AND k not in record.lifetime_causes

    AUC is computed along with its DeLong variance.
    """
    n_records = int(len(records))
    if n_records == 0:
        return pd.DataFrame(
            columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
        )

    K = int(scores.shape[1])
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause)) for r in records],
        dtype=np.int64,
    )

    # Pre-compute lifetime disease membership matrix for vectorized
    # clean-control filtering:
    # lifetime_matrix[i, c] == True iff cause c is present in
    # records[i].lifetime_causes.
    # Use a sparse matrix when SciPy is available to keep memory bounded for
    # large K.
    row_parts: List[np.ndarray] = []
    col_parts: List[np.ndarray] = []
    for i, r in enumerate(records):
        causes = getattr(r, "lifetime_causes", None)
        if causes is None:
            continue
        causes = np.asarray(causes, dtype=np.int64)
        if causes.size == 0:
            continue
        # Keep only valid cause ids.
        m_valid = (causes >= 0) & (causes < K)
        if not np.any(m_valid):
            continue
        causes = causes[m_valid]
        row_parts.append(np.full((causes.size,), i, dtype=np.int32))
        col_parts.append(causes.astype(np.int32, copy=False))

    try:
        import scipy.sparse as sp  # type: ignore

        if row_parts:
            rows = np.concatenate(row_parts, axis=0)
            cols = np.concatenate(col_parts, axis=0)
            data = np.ones((rows.size,), dtype=bool)
            lifetime_matrix = sp.csc_matrix(
                (data, (rows, cols)), shape=(n_records, K))
        else:
            lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
        lifetime_is_sparse = True
    except Exception:  # pragma: no cover
        # Dense fallback when SciPy is unavailable.
        lifetime_matrix = np.zeros((n_records, K), dtype=bool)
        for rows, cols in zip(row_parts, col_parts):
            lifetime_matrix[rows.astype(np.int64, copy=False),
                            cols.astype(np.int64, copy=False)] = True
        lifetime_is_sparse = False

    auc = np.full((K,), np.nan, dtype=np.float64)
    var = np.full((K,), np.nan, dtype=np.float64)
    n_case = np.zeros((K,), dtype=np.int64)
    n_control = np.zeros((K,), dtype=np.int64)

    for k in range(K):
        case_mask = y_next == k
        if not np.any(case_mask):
            continue
        # Clean controls: not next-event k AND never had k in their lifetime
        # history.
        control_mask = y_next != k
        if np.any(control_mask):
            if lifetime_is_sparse:
                had_k = np.asarray(
                    lifetime_matrix.getcol(k).toarray().ravel(), dtype=bool)
            else:
                had_k = lifetime_matrix[:, k]
            is_clean = ~had_k
            control_mask = control_mask & is_clean
        cs = scores[case_mask, k]
        hs = scores[control_mask, k]
        n_case[k] = int(cs.size)
        n_control[k] = int(hs.size)
        if cs.size == 0 or hs.size == 0:
            continue
        a, v = get_auc_delong_var(hs, cs)
        auc[k] = float(a)
        var[k] = float(v)

    return pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_case": n_case,
            "n_control": n_control,
            "auc": auc,
            "auc_variance": var,
        }
    )
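
# `get_auc_delong_var` is imported from `utils`. As a reference for what it is
# assumed to compute, here is a minimal standalone sketch of the standard
# DeLong AUC/variance estimator (Sun & Xu fast-DeLong formulation). The
# argument order (controls first, then cases) is inferred from the call site
# above; this function is documentation only and is not used by the pipeline.
def _delong_auc_variance_sketch(controls: np.ndarray, cases: np.ndarray):
    def midrank(x: np.ndarray) -> np.ndarray:
        # 1-based ranks with ties replaced by their average (mid) rank.
        order = np.argsort(x, kind="mergesort")
        sx = x[order]
        ranks = np.empty(x.size, dtype=np.float64)
        i = 0
        while i < x.size:
            j = i
            while j < x.size and sx[j] == sx[i]:
                j += 1
            ranks[order[i:j]] = 0.5 * (i + j - 1) + 1.0
            i = j
        return ranks

    m, n = cases.size, controls.size
    tz = midrank(np.concatenate([cases, controls]))
    tx = midrank(cases)
    ty = midrank(controls)
    # Mann-Whitney AUC from midranks of the cases within the pooled sample.
    auc = (tz[:m].sum() - m * (m + 1) / 2.0) / (m * n)
    # DeLong structural components and their sample variances.
    v_cases = (tz[:m] - tx) / n
    v_controls = 1.0 - (tz[m:] - ty) / m
    s10 = v_cases.var(ddof=1) if m > 1 else 0.0
    s01 = v_controls.var(ddof=1) if n > 1 else 0.0
    return float(auc), float(s10 / m + s01 / n)
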
def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    show_progress = not args.no_tqdm

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
        n_jobs=int(args.max_cpu_cores),
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    scores = predict_next_token_logits(
        model,
        head,
        loader,
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-token)",
        return_probs=True,
    )

    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause)) for r in records],
        dtype=np.int64,
    )

    # Strict protocol: evaluate independently per age bin (no mixing across
    # bins); output files preserve the existing names and shapes.
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1}).
    # side="right" keeps the interval half-open, so a t0 exactly equal to b_i
    # falls in bin i rather than bin i - 1.
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1

    per_bin_metric_rows: List[dict] = []
    per_bin_auc_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)
        m = bin_idx == b
        m_scores = scores[m]
        m_records = [r for r, keep in zip(records, m) if bool(keep)]

        # Coverage metric for transparency (not the Delphi-2M AUC itself).
        m_y = y_next[m]
        n_total = int(m_y.size)
        n_eligible = int((m_y >= 0).sum())
        coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "n_records_total", "value": n_total})
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible})
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "coverage", "value": coverage})

        m_auc = _compute_next_event_auc_clean_control(
            scores=m_scores,
            records=m_records,
        )
        m_auc.insert(0, "age_bin", label)
        per_bin_auc_parts.append(m_auc)

    out_metrics_bins = os.path.join(
        run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)

    out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
    if per_bin_auc_parts:
        pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
            out_auc_bins, index=False)
    else:
        pd.DataFrame(
            columns=["age_bin", "cause_id", "n_case",
                     "n_control", "auc", "auc_variance"]
        ).to_csv(out_auc_bins, index=False)

    print("PRIMARY METRICS: Per-cause AUC is reported per age bin using "
          "Delphi-2M clean controls.")
    print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_auc_bins}")


if __name__ == "__main__":
    main()
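
# Example invocation (the script filename and run_dir path are illustrative):
#   python evaluate_next_event.py --run_dir runs/my_experiment \
#       --age_bins 40 45 50 55 60 65 70 inf --batch_size 256 --num_workers 4
#
# Outputs written into --run_dir:
#   next_event_metrics_by_age_bin.csv  (counts and coverage per age bin)
#   next_event_auc_by_age_bin.csv      (per-cause clean-control AUC + DeLong variance)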