From b80d9a4256a6bc5c0daf0fdd42f166f81442af3f Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Sun, 18 Jan 2026 15:41:07 +0800
Subject: [PATCH] Remove evaluate_horizon.py, evaluate_next_event.py and
 utils.py to streamline the codebase. These files contained evaluation and
 utility functions and classes that are no longer needed.

---
 evaluate_horizon.py    | 455 --------
 evaluate_next_event.py | 311 --------------
 utils.py               | 925 -----------------------------------------
 3 files changed, 1691 deletions(-)
 delete mode 100644 evaluate_horizon.py
 delete mode 100644 evaluate_next_event.py
 delete mode 100644 utils.py

diff --git a/evaluate_horizon.py b/evaluate_horizon.py
deleted file mode 100644
index a288f9a..0000000
--- a/evaluate_horizon.py
+++ /dev/null
@@ -1,455 +0,0 @@
-"""Horizon-capture evaluation (event-driven, age-stratified).
-
-This script implements the protocol described in 评估方案.md:
-
-- Age-stratified evaluation: metrics are computed independently within each age bin (no mixing).
-- Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and
-  there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin
-  disease events with t0 >= DOA.
-- No follow-up completeness filtering (no t0+tau <= t_end constraint).
-
-Primary outputs per age bin:
-- Top-K Event Capture@tau (event-count based)
-- Workload–Yield curves (Top-p% people by a person-level horizon score)
-
-Secondary (diagnostic-only) outputs per age bin:
-- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment)
-"""
-
-import argparse
-import math
-import os
-from typing import Dict, List, Sequence, Tuple
-
-import numpy as np
-import pandas as pd
-import torch
-from torch.utils.data import DataLoader
-
-try:
-    from tqdm import tqdm
-except Exception:  # pragma: no cover
-    tqdm = None
-
-from utils import (
-    EvalRecordDataset,
-    build_dataset_from_config,
-    build_event_driven_records,
-    build_model_head_criterion,
-    eval_collate_fn,
-    flatten_future_events,
-    get_test_subset,
-    load_checkpoint_into,
-    load_train_config,
-    make_inference_dataloader_kwargs,
-    parse_float_list,
-    predict_cifs,
-    roc_auc_ovr,
-    seed_everything,
-    topk_indices,
-    DAYS_PER_YEAR,
-)
-
-
-def parse_args() -> argparse.Namespace:
-    p = argparse.ArgumentParser(
-        description="Evaluate horizon-capture using CIF at horizons")
-    p.add_argument("--run_dir", type=str, required=True)
-    p.add_argument(
-        "--horizons",
-        type=str,
-        nargs="+",
-        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
-        help="Horizon grid in years",
-    )
-    p.add_argument(
-        "--age_bins",
-        type=str,
-        nargs="+",
-        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
-        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
-    )
-
-    p.add_argument(
-        "--device",
-        type=str,
-        default=("cuda" if torch.cuda.is_available() else "cpu"),
-    )
-    p.add_argument("--batch_size", type=int, default=256)
-    p.add_argument("--num_workers", type=int, default=0)
-    p.add_argument(
-        "--max_cpu_cores",
-        type=int,
-        default=-1,
-        help="Maximum number of CPU cores to use for parallel data construction.",
-    )
-    p.add_argument("--seed", type=int, default=0)
-    p.add_argument("--min_pos", type=int, default=20)
-    p.add_argument(
-        "--topk_list",
-        type=int,
-        nargs="+",
-        default=[5, 10, 20, 50],
-    )
-    p.add_argument(
-        "--workload_fracs",
-        type=float,
-        nargs="+",
-        default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5],
-        help="Fractions for workload–yield curves (Top-p%% people).",
-    )
-    p.add_argument(
"--no_tqdm", - action="store_true", - help="Disable tqdm progress bars", - ) - return p.parse_args() - - -def _format_age_bin_label(lo: float, hi: float) -> str: - if np.isinf(hi): - return f"[{lo}, inf)" - return f"[{lo}, {hi})" - - -def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray: - age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64) - age_bins_days = age_bins_years * DAYS_PER_YEAR - return np.searchsorted(age_bins_days, t0_days, side="left") - 1 - - -def _event_counts_within_tau( - n_records: int, - event_record_idx: np.ndarray, - event_dt_years: np.ndarray, - tau_years: float, -) -> np.ndarray: - """Count events within (t0, t0+tau] per record (event-count, not unique causes).""" - if event_record_idx.size == 0: - return np.zeros((n_records,), dtype=np.int64) - m = event_dt_years <= float(tau_years) - if not np.any(m): - return np.zeros((n_records,), dtype=np.int64) - return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64) - - -def build_labels_within_tau_flat( - n_records: int, - n_causes: int, - event_record_idx: np.ndarray, - event_cause: np.ndarray, - event_dt_years: np.ndarray, - tau_years: float, -) -> np.ndarray: - """Build y_within_tau using a flattened (record,cause,dt) representation. - - This preserves the exact label definition: y[i,k]=1 iff at least one event of cause k - occurs in (t0, t0+tau]. - """ - y = np.zeros((n_records, n_causes), dtype=np.int8) - if event_dt_years.size == 0: - return y - m = event_dt_years <= float(tau_years) - if not np.any(m): - return y - y[event_record_idx[m], event_cause[m]] = 1 - return y - - -def main() -> None: - args = parse_args() - seed_everything(args.seed) - - show_progress = (not args.no_tqdm) - - run_dir = args.run_dir - cfg = load_train_config(run_dir) - - dataset = build_dataset_from_config(cfg) - test_subset = get_test_subset(dataset, cfg) - - age_bins_years = parse_float_list(args.age_bins) - horizons = parse_float_list(args.horizons) - horizons = [float(h) for h in horizons] - - records = build_event_driven_records( - subset=test_subset, - age_bins_years=age_bins_years, - seed=args.seed, - show_progress=show_progress, - n_jobs=int(args.max_cpu_cores), - ) - - device = torch.device(args.device) - model, head, criterion = build_model_head_criterion(cfg, dataset, device) - load_checkpoint_into(run_dir, model, head, criterion, device) - - rec_ds = EvalRecordDataset(test_subset, records) - - dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers) - loader = DataLoader( - rec_ds, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.num_workers, - collate_fn=eval_collate_fn, - **dl_kwargs, - ) - - # Print disclaimers every run (requested) - print("PRIMARY METRICS: event-count Capture@K and Workload–Yield, computed independently per age bin.") - print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).") - - scores = predict_cifs( - model, - head, - criterion, - loader, - horizons, - device=device, - show_progress=show_progress, - progress_desc="Inference (horizons)", - ) - # scores shape: (N, K, H) - if scores.ndim != 3: - raise ValueError( - f"Expected CIF scores with shape (N,K,H), got {scores.shape}") - - N, K, H = scores.shape - if N != len(records): - raise ValueError("Record count mismatch") - - # Pre-flatten all future events once to avoid repeated per-record scans. - # NOTE: these are event-level arrays (not unique causes), suitable for event-count Capture@K. 
- evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K) - - # Assign each record to an age bin (based on t0; by construction t0 is within the bin). - t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64) - bin_idx = _assign_age_bin_idx(t0_days, age_bins_years) - age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64) - - capture_rows: List[Dict[str, object]] = [] - workload_rows: List[Dict[str, object]] = [] - - # Diagnostics (optional): approximate event-driven AUC/Brier computed per bin. - diag_rows: List[Dict[str, object]] = [] - diag_per_cause_parts: List[pd.DataFrame] = [] - - bins_iter = range(len(age_bins_years_arr) - 1) - if show_progress and tqdm is not None: - bins_iter = tqdm(bins_iter, total=len( - age_bins_years_arr) - 1, desc="Age bins") - - for b in bins_iter: - lo = float(age_bins_years_arr[b]) - hi = float(age_bins_years_arr[b + 1]) - age_label = _format_age_bin_label(lo, hi) - - m_rec = bin_idx == b - n_bin = int(m_rec.sum()) - if n_bin == 0: - continue - - rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32) - - # Filter events to this bin's records once. - m_evt_bin = m_rec[evt_rec_idx] if evt_rec_idx.size > 0 else np.zeros( - (0,), dtype=bool) - evt_rec_idx_b = evt_rec_idx[m_evt_bin] - evt_cause_b = evt_cause[m_evt_bin] - evt_dt_b = evt_dt[m_evt_bin] - - horizon_iter = enumerate(horizons) - if show_progress and tqdm is not None: - horizon_iter = tqdm(horizon_iter, total=len( - horizons), desc=f"Horizons {age_label}") - - # Precompute a local index mapping for diagnostics label building. - local_map = np.full((N,), -1, dtype=np.int32) - local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32) - - for h_idx, tau in horizon_iter: - s_tau_all = scores[:, :, h_idx] - s_tau = s_tau_all[m_rec] - - # ------------------------- - # Primary metric: Top-K Event Capture@tau (event-count based) - # ------------------------- - denom_events = int(np.sum(evt_dt_b <= float(tau)) - ) if evt_dt_b.size > 0 else 0 - if denom_events == 0: - for topk in args.topk_list: - capture_rows.append( - { - "age_bin": age_label, - "tau_years": float(tau), - "topk": int(topk), - "capture_at_k": float("nan"), - "denom_events": int(0), - "numer_events": int(0), - "n_records": int(n_bin), - "n_causes": int(K), - } - ) - else: - m_evt_tau = evt_dt_b <= float(tau) - evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau] - evt_cause_tau = evt_cause_b[m_evt_tau] - - # For each K, compute whether each event's cause is in that record's Top-K list. - for topk in args.topk_list: - topk = int(topk) - idx = topk_indices(s_tau_all, topk) # shape (N, topk) - idx_for_events = idx[evt_rec_idx_tau] - hits = (idx_for_events == - evt_cause_tau[:, None]).any(axis=1) - numer_events = int(hits.sum()) - capture = float(numer_events / denom_events) - capture_rows.append( - { - "age_bin": age_label, - "tau_years": float(tau), - "topk": int(topk), - "capture_at_k": capture, - "denom_events": int(denom_events), - "numer_events": int(numer_events), - "n_records": int(n_bin), - "n_causes": int(K), - } - ) - - # ------------------------- - # Primary metric: Workload–Yield (Top-p% people) - # ------------------------- - # Person-level score: max_k CIF_k(tau). This is used only for workload–yield ranking. 
- person_score = s_tau.max(axis=1) if K > 0 else np.zeros( - (n_bin,), dtype=np.float64) - order = np.argsort(-person_score, kind="mergesort") - - counts_per_record = _event_counts_within_tau( - n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau) - total_events = int(counts_per_record.sum()) - overall_events_per_person = ( - total_events / float(n_bin)) if n_bin > 0 else float("nan") - - for frac in args.workload_fracs: - frac = float(frac) - if frac <= 0.0: - continue - n_sel = int(math.ceil(frac * n_bin)) - n_sel = min(max(n_sel, 1), n_bin) - sel_local = order[:n_sel] - events_captured = int(counts_per_record[sel_local].sum()) - capture_rate = float( - events_captured / total_events) if total_events > 0 else float("nan") - - selected_events_per_person = ( - events_captured / float(n_sel)) if n_sel > 0 else float("nan") - lift = (selected_events_per_person / - overall_events_per_person) if overall_events_per_person > 0 else float("nan") - - workload_rows.append( - { - "age_bin": age_label, - "tau_years": float(tau), - "frac_selected": float(frac), - "n_selected": int(n_sel), - "n_records": int(n_bin), - "total_events": int(total_events), - "events_captured": int(events_captured), - "capture_rate": capture_rate, - "lift_events_per_person": float(lift), - "person_score_def": "max_k_CIF_k(tau)", - } - ) - - # ------------------------- - # Diagnostics (optional): approximate event-driven AUC/Brier - # ------------------------- - # Convert event-level data to binary labels y[i,k]=1 iff >=1 event of cause k within tau. - y_tau_bin = np.zeros((n_bin, K), dtype=np.int8) - if evt_dt_b.size > 0: - m_evt_tau = evt_dt_b <= float(tau) - if np.any(m_evt_tau): - rec_local = local_map[evt_rec_idx_b[m_evt_tau]] - valid = rec_local >= 0 - y_tau_bin[rec_local[valid], - evt_cause_b[m_evt_tau][valid]] = 1 - - n_pos = y_tau_bin.sum(axis=0).astype(np.int64) - n_neg = (int(n_bin) - n_pos).astype(np.int64) - - brier_per_cause = np.mean( - (y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0 - ) - brier_macro = float(np.mean(brier_per_cause) - ) if K > 0 else float("nan") - brier_weighted = float(np.sum( - brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan") - - auc = np.full((K,), np.nan, dtype=np.float64) - min_pos = int(args.min_pos) - candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0)) - for k in candidates: - auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype( - np.int32), s_tau[:, k].astype(np.float64)) - - finite_auc = auc[np.isfinite(auc)] - auc_macro = float(np.mean(finite_auc) - ) if finite_auc.size > 0 else float("nan") - w_mask = np.isfinite(auc) - auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum( - n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan") - n_valid_auc = int(np.isfinite(auc).sum()) - - diag_rows.append( - { - "age_bin": age_label, - "tau_years": float(tau), - "n_records": int(n_bin), - "n_causes": int(K), - "auc_macro": auc_macro, - "auc_weighted_by_npos": auc_weighted, - "n_causes_valid_auc": int(n_valid_auc), - "brier_macro": brier_macro, - "brier_weighted_by_npos": brier_weighted, - } - ) - - diag_per_cause_parts.append( - pd.DataFrame( - { - "age_bin": age_label, - "tau_years": float(tau), - "cause_id": np.arange(K, dtype=np.int64), - "n_pos": n_pos, - "n_neg": n_neg, - "auc": auc, - "brier": brier_per_cause, - } - ) - ) - - out_capture = os.path.join(run_dir, "horizon_capture.csv") - out_wy = os.path.join(run_dir, "workload_yield.csv") - out_diag = os.path.join(run_dir, "horizon_metrics.csv") - out_diag_pc = 
os.path.join(run_dir, "horizon_per_cause.csv") - - pd.DataFrame(capture_rows).to_csv(out_capture, index=False) - pd.DataFrame(workload_rows).to_csv(out_wy, index=False) - pd.DataFrame(diag_rows).to_csv(out_diag, index=False) - if diag_per_cause_parts: - pd.concat(diag_per_cause_parts, ignore_index=True).to_csv( - out_diag_pc, index=False) - else: - pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos", - "n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False) - - print(f"Wrote {out_capture}") - print(f"Wrote {out_wy}") - print(f"Wrote {out_diag} (diagnostic-only)") - print(f"Wrote {out_diag_pc} (diagnostic-only)") - - -if __name__ == "__main__": - main() diff --git a/evaluate_next_event.py b/evaluate_next_event.py deleted file mode 100644 index 7858366..0000000 --- a/evaluate_next_event.py +++ /dev/null @@ -1,311 +0,0 @@ -import argparse -import os -from typing import List - -import numpy as np -import pandas as pd -import torch -from torch.utils.data import DataLoader - -try: - from tqdm import tqdm # noqa: F401 -except Exception: # pragma: no cover - tqdm = None - -from utils import ( - EvalRecordDataset, - build_dataset_from_config, - build_event_driven_records, - build_model_head_criterion, - eval_collate_fn, - get_test_subset, - make_inference_dataloader_kwargs, - load_checkpoint_into, - load_train_config, - parse_float_list, - predict_next_token_logits, - get_auc_delong_var, - seed_everything, - DAYS_PER_YEAR, -) - - -def parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser( - description="Evaluate next-event prediction using next-token scores" - ) - p.add_argument("--run_dir", type=str, required=True) - p.add_argument( - "--age_bins", - type=str, - nargs="+", - default=["40", "45", "50", "55", "60", "65", "70", "inf"], - help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)", - ) - - p.add_argument( - "--device", - type=str, - default=("cuda" if torch.cuda.is_available() else "cpu"), - ) - p.add_argument("--batch_size", type=int, default=256) - p.add_argument("--num_workers", type=int, default=0) - p.add_argument( - "--max_cpu_cores", - type=int, - default=-1, - help="Maximum number of CPU cores to use for parallel data construction.", - ) - p.add_argument("--seed", type=int, default=0) - p.add_argument( - "--min_pos", - type=int, - default=20, - help="Minimum positives for per-cause AUC", - ) - p.add_argument( - "--no_tqdm", - action="store_true", - help="Disable tqdm progress bars", - ) - return p.parse_args() - - -def _format_age_bin_label(lo: float, hi: float) -> str: - if np.isinf(hi): - return f"[{lo}, inf)" - return f"[{lo}, {hi})" - - -def _compute_next_event_auc_clean_control( - *, - scores: np.ndarray, - records: list, -) -> pd.DataFrame: - """Delphi-2M next-event AUC (clean control) per cause. - - Definitions per cause k: - - Case: next_event_cause == k - - Control (clean): next_event_cause != k AND k not in record.lifetime_causes - AUC is computed with DeLong variance. - """ - n_records = int(len(records)) - if n_records == 0: - return pd.DataFrame( - columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"], - ) - - K = int(scores.shape[1]) - y_next = np.array( - [(-1 if r.next_event_cause is None else int(r.next_event_cause)) - for r in records], - dtype=np.int64, - ) - - # Pre-compute lifetime disease membership matrix for vectorized clean-control filtering. - # lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes. 
- # Use a sparse matrix when SciPy is available to keep memory bounded for large K. - row_parts: List[np.ndarray] = [] - col_parts: List[np.ndarray] = [] - for i, r in enumerate(records): - causes = getattr(r, "lifetime_causes", None) - if causes is None: - continue - causes = np.asarray(causes, dtype=np.int64) - if causes.size == 0: - continue - # Keep only valid cause ids. - m_valid = (causes >= 0) & (causes < K) - if not np.any(m_valid): - continue - causes = causes[m_valid] - row_parts.append(np.full((causes.size,), i, dtype=np.int32)) - col_parts.append(causes.astype(np.int32, copy=False)) - - try: - import scipy.sparse as sp # type: ignore - - if row_parts: - rows = np.concatenate(row_parts, axis=0) - cols = np.concatenate(col_parts, axis=0) - data = np.ones((rows.size,), dtype=bool) - lifetime_matrix = sp.csc_matrix( - (data, (rows, cols)), shape=(n_records, K)) - else: - lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool) - lifetime_is_sparse = True - except Exception: # pragma: no cover - lifetime_matrix = np.zeros((n_records, K), dtype=bool) - for rows, cols in zip(row_parts, col_parts): - lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype( - np.int64, copy=False)] = True - lifetime_is_sparse = False - - auc = np.full((K,), np.nan, dtype=np.float64) - var = np.full((K,), np.nan, dtype=np.float64) - n_case = np.zeros((K,), dtype=np.int64) - n_control = np.zeros((K,), dtype=np.int64) - - for k in range(K): - case_mask = y_next == k - if not np.any(case_mask): - continue - - # Clean controls: not next-event k AND never had k in their lifetime history. - control_mask = y_next != k - if np.any(control_mask): - if lifetime_is_sparse: - had_k = np.asarray(lifetime_matrix.getcol( - k).toarray().ravel(), dtype=bool) - else: - had_k = lifetime_matrix[:, k] - is_clean = ~had_k - control_mask = control_mask & is_clean - - cs = scores[case_mask, k] - hs = scores[control_mask, k] - n_case[k] = int(cs.size) - n_control[k] = int(hs.size) - if cs.size == 0 or hs.size == 0: - continue - - a, v = get_auc_delong_var(hs, cs) - auc[k] = float(a) - var[k] = float(v) - - return pd.DataFrame( - { - "cause_id": np.arange(K, dtype=np.int64), - "n_case": n_case, - "n_control": n_control, - "auc": auc, - "auc_variance": var, - } - ) - - -def main() -> None: - args = parse_args() - - # Best-effort control of implicit parallelism to avoid CPU oversubscription. - # Note: environment variables are ideally set before importing NumPy/PyTorch, - # but setting them early in main can still affect subprocesses or lazy readers. 
- if int(args.max_cpu_cores) > 0: - n_threads = int(args.max_cpu_cores) - torch.set_num_threads(n_threads) - for k in ( - "OMP_NUM_THREADS", - "MKL_NUM_THREADS", - "OPENBLAS_NUM_THREADS", - "VECLIB_MAXIMUM_THREADS", - "NUMEXPR_NUM_THREADS", - ): - os.environ[k] = str(n_threads) - print(f"Restricting implicit parallelism to {n_threads} threads.") - seed_everything(args.seed) - - show_progress = (not args.no_tqdm) - - run_dir = args.run_dir - cfg = load_train_config(run_dir) - - dataset = build_dataset_from_config(cfg) - test_subset = get_test_subset(dataset, cfg) - - age_bins_years = parse_float_list(args.age_bins) - records = build_event_driven_records( - subset=test_subset, - age_bins_years=age_bins_years, - seed=args.seed, - show_progress=show_progress, - n_jobs=int(args.max_cpu_cores), - ) - - device = torch.device(args.device) - model, head, criterion = build_model_head_criterion(cfg, dataset, device) - load_checkpoint_into(run_dir, model, head, criterion, device) - - rec_ds = EvalRecordDataset(test_subset, records) - - dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers) - loader = DataLoader( - rec_ds, - batch_size=args.batch_size, - shuffle=False, - num_workers=args.num_workers, - collate_fn=eval_collate_fn, - **dl_kwargs, - ) - - scores = predict_next_token_logits( - model, - head, - loader, - device=device, - show_progress=show_progress, - progress_desc="Inference (next-token)", - return_probs=True, - ) - - y_next = np.array( - [(-1 if r.next_event_cause is None else int(r.next_event_cause)) - for r in records], - dtype=np.int64, - ) - - # Overall (preserve existing output files/shape) - # Strict protocol: evaluate independently per age bin (no mixing). - age_bins_years = np.asarray(age_bins_years, dtype=np.float64) - age_bins_days = age_bins_years * DAYS_PER_YEAR - # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1}) - t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64) - bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1 - - per_bin_metric_rows: List[dict] = [] - per_bin_auc_parts: List[pd.DataFrame] = [] - for b in range(len(age_bins_years) - 1): - lo = float(age_bins_years[b]) - hi = float(age_bins_years[b + 1]) - label = _format_age_bin_label(lo, hi) - m = bin_idx == b - m_scores = scores[m] - m_records = [r for r, keep in zip(records, m) if bool(keep)] - # Coverage metric for transparency (not Delphi-2M AUC itself). 
- m_y = y_next[m] - n_total = int(m_y.size) - n_eligible = int((m_y >= 0).sum()) - coverage = float(n_eligible / n_total) if n_total > 0 else 0.0 - per_bin_metric_rows.append( - {"age_bin": label, "metric": "n_records_total", "value": n_total}) - per_bin_metric_rows.append( - {"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible}) - per_bin_metric_rows.append( - {"age_bin": label, "metric": "coverage", "value": coverage}) - - m_auc = _compute_next_event_auc_clean_control( - scores=m_scores, - records=m_records, - ) - m_auc.insert(0, "age_bin", label) - per_bin_auc_parts.append(m_auc) - - out_metrics_bins = os.path.join( - run_dir, "next_event_metrics_by_age_bin.csv") - pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False) - - out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv") - if per_bin_auc_parts: - pd.concat(per_bin_auc_parts, ignore_index=True).to_csv( - out_auc_bins, index=False) - else: - pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control", - "auc", "auc_variance"]).to_csv(out_auc_bins, index=False) - - print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.") - print("EVAL METHOD: DeLong AUC variance is reported (per cause).") - print(f"Wrote {out_metrics_bins}") - print(f"Wrote {out_auc_bins}") - - -if __name__ == "__main__": - main() diff --git a/utils.py b/utils.py deleted file mode 100644 index d634382..0000000 --- a/utils.py +++ /dev/null @@ -1,925 +0,0 @@ -import json -import math -import os -import random -import re -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import numpy as np -import torch -from torch.utils.data import DataLoader, Dataset, Subset, random_split - -try: - from tqdm import tqdm as _tqdm -except Exception: # pragma: no cover - _tqdm = None - -try: - from joblib import Parallel, delayed # type: ignore -except Exception: # pragma: no cover - Parallel = None - delayed = None - -from dataset import HealthDataset -from losses import ( - DiscreteTimeCIFNLLLoss, - ExponentialNLLLoss, - PiecewiseExponentialCIFNLLLoss, -) -from model import DelphiFork, SapDelphi, SimpleHead - - -DAYS_PER_YEAR = 365.25 -N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2 - - -def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None): - if enabled and _tqdm is not None: - return _tqdm(iterable, desc=desc, total=total) - return iterable - - -def make_inference_dataloader_kwargs( - device: torch.device, - num_workers: int, -) -> Dict[str, Any]: - """DataLoader kwargs tuned for inference throughput. - - Behavior/metrics are unchanged; this only impacts speed. - """ - use_cuda = device.type == "cuda" and torch.cuda.is_available() - kwargs: Dict[str, Any] = { - "pin_memory": bool(use_cuda), - } - if num_workers > 0: - kwargs["persistent_workers"] = True - # default prefetch is 2; set explicitly for clarity. - kwargs["prefetch_factor"] = 2 - return kwargs - - -# ------------------------- -# Config + determinism -# ------------------------- - -def _replace_nonstandard_json_numbers(text: str) -> str: - # Python's json.dump writes Infinity/-Infinity/NaN for non-finite floats. - # Replace bare tokens (not within quotes) with string placeholders. - def repl(match: re.Match[str]) -> str: - token = match.group(0) - if token == "-Infinity": - return '"__NINF__"' - if token == "Infinity": - return '"__INF__"' - if token == "NaN": - return '"__NAN__"' - return token - - return re.sub(r'(? 
Any: - if isinstance(obj, dict): - return {k: _restore_placeholders(v) for k, v in obj.items()} - if isinstance(obj, list): - return [_restore_placeholders(v) for v in obj] - if obj == "__INF__": - return float("inf") - if obj == "__NINF__": - return float("-inf") - if obj == "__NAN__": - return float("nan") - return obj - - -def load_train_config(run_dir: str) -> Dict[str, Any]: - cfg_path = os.path.join(run_dir, "train_config.json") - with open(cfg_path, "r", encoding="utf-8") as f: - raw = f.read() - raw = _replace_nonstandard_json_numbers(raw) - cfg = json.loads(raw) - cfg = _restore_placeholders(cfg) - return cfg - - -def seed_everything(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def parse_float_list(values: Sequence[str]) -> List[float]: - out: List[float] = [] - for v in values: - s = str(v).strip().lower() - if s in {"inf", "+inf", "infty", "infinity", "+infinity"}: - out.append(float("inf")) - elif s in {"-inf", "-infty", "-infinity"}: - out.append(float("-inf")) - else: - out.append(float(v)) - return out - - -# ------------------------- -# Dataset + split (match train.py) -# ------------------------- - -def build_dataset_from_config(cfg: Dict[str, Any]) -> HealthDataset: - data_prefix = cfg["data_prefix"] - full_cov = bool(cfg.get("full_cov", False)) - - if full_cov: - cov_list = None - else: - cov_list = ["bmi", "smoking", "alcohol"] - - dataset = HealthDataset( - data_prefix=data_prefix, - covariate_list=cov_list, - ) - return dataset - - -def get_test_subset(dataset: HealthDataset, cfg: Dict[str, Any]) -> Subset: - n_total = len(dataset) - train_ratio = float(cfg["train_ratio"]) - val_ratio = float(cfg["val_ratio"]) - seed = int(cfg["random_seed"]) - - n_train = int(n_total * train_ratio) - n_val = int(n_total * val_ratio) - n_test = n_total - n_train - n_val - - _, _, test_subset = random_split( - dataset, - [n_train, n_val, n_test], - generator=torch.Generator().manual_seed(seed), - ) - return test_subset - - -# ------------------------- -# Model + head + criterion (match train.py) -# ------------------------- - -def build_model_head_criterion( - cfg: Dict[str, Any], - dataset: HealthDataset, - device: torch.device, -) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]: - loss_type = cfg["loss_type"] - - if loss_type == "exponential": - criterion = ExponentialNLLLoss(lambda_reg=float( - cfg.get("lambda_reg", 0.0))).to(device) - out_dims = [dataset.n_disease] - elif loss_type == "discrete_time_cif": - bin_edges = [float(x) for x in cfg["bin_edges"]] - criterion = DiscreteTimeCIFNLLLoss( - bin_edges=bin_edges, - lambda_reg=float(cfg.get("lambda_reg", 0.0)), - ).to(device) - out_dims = [dataset.n_disease + 1, len(bin_edges)] - elif loss_type == "pwe_cif": - # training drops +inf for PWE - raw_edges = [float(x) for x in cfg["bin_edges"]] - pwe_edges = [float(x) for x in raw_edges if math.isfinite(float(x))] - if len(pwe_edges) < 2: - raise ValueError( - "pwe_cif requires at least 2 finite bin edges (including 0). 
" - f"Got bin_edges={raw_edges}" - ) - if float(pwe_edges[0]) != 0.0: - raise ValueError( - f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}") - - criterion = PiecewiseExponentialCIFNLLLoss( - bin_edges=pwe_edges, - lambda_reg=float(cfg.get("lambda_reg", 0.0)), - ).to(device) - n_bins = len(pwe_edges) - 1 - out_dims = [dataset.n_disease, n_bins] - else: - raise ValueError(f"Unsupported loss_type: {loss_type}") - - model_type = cfg["model_type"] - if model_type == "delphi_fork": - model = DelphiFork( - n_disease=dataset.n_disease, - n_tech_tokens=N_TECH_TOKENS, - n_embd=int(cfg["n_embd"]), - n_head=int(cfg["n_head"]), - n_layer=int(cfg["n_layer"]), - pdrop=float(cfg.get("pdrop", 0.0)), - age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), - n_cont=int(dataset.n_cont), - n_cate=int(dataset.n_cate), - cate_dims=list(dataset.cate_dims), - ).to(device) - elif model_type == "sap_delphi": - model = SapDelphi( - n_disease=dataset.n_disease, - n_tech_tokens=N_TECH_TOKENS, - n_embd=int(cfg["n_embd"]), - n_head=int(cfg["n_head"]), - n_layer=int(cfg["n_layer"]), - pdrop=float(cfg.get("pdrop", 0.0)), - age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")), - n_cont=int(dataset.n_cont), - n_cate=int(dataset.n_cate), - cate_dims=list(dataset.cate_dims), - pretrained_weights_path=str( - cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")), - freeze_embeddings=True, - ).to(device) - else: - raise ValueError(f"Unsupported model_type: {model_type}") - - head = SimpleHead( - n_embd=int(cfg["n_embd"]), - out_dims=list(out_dims), - ).to(device) - - return model, head, criterion - - -def load_checkpoint_into( - run_dir: str, - model: torch.nn.Module, - head: torch.nn.Module, - criterion: Optional[torch.nn.Module], - device: torch.device, -) -> Dict[str, Any]: - ckpt_path = os.path.join(run_dir, "best_model.pt") - ckpt = torch.load(ckpt_path, map_location=device) - model.load_state_dict(ckpt["model_state_dict"], strict=True) - head.load_state_dict(ckpt["head_state_dict"], strict=True) - if criterion is not None and "criterion_state_dict" in ckpt: - try: - criterion.load_state_dict( - ckpt["criterion_state_dict"], strict=False) - except Exception: - # Criterion state is not essential for inference. 
- pass - return ckpt - - -# ------------------------- -# Evaluation record construction (event-driven) -# ------------------------- - -@dataclass(frozen=True) -class EvalRecord: - subset_idx: int - doa_days: float - t0_days: float - cutoff_pos: int # baseline position (inclusive) - next_event_cause: Optional[int] - next_event_dt_years: Optional[float] - # (U,) unique causes ever observed (clean-control filtering) - lifetime_causes: np.ndarray - future_causes: np.ndarray # (E,) in [0..K-1] - future_dt_years: np.ndarray # (E,) strictly > 0 - - -def _to_days(x_years: float) -> float: - if math.isinf(float(x_years)): - return float("inf") - return float(x_years) * DAYS_PER_YEAR - - -def build_event_driven_records( - subset: Subset, - age_bins_years: Sequence[float], - seed: int, - show_progress: bool = False, - n_jobs: int = 1, - chunk_size: int = 256, - prefer: str = "threads", -) -> List[EvalRecord]: - if len(age_bins_years) < 2: - raise ValueError("age_bins must have at least 2 boundaries") - - age_bins_days = [_to_days(b) for b in age_bins_years] - if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)): - raise ValueError("age_bins must be strictly increasing") - - def _iter_chunks(n: int, size: int) -> List[np.ndarray]: - if size <= 0: - raise ValueError("chunk_size must be >= 1") - if n == 0: - return [] - idx = np.arange(n, dtype=np.int64) - return [idx[i:i + size] for i in range(0, n, size)] - - def _build_records_for_index( - subset_idx: int, - *, - age_bins_days_local: Sequence[float], - rng_local: np.random.Generator, - ) -> List[EvalRecord]: - event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)] - codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False) - times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False) - - doa_pos = np.flatnonzero(codes_ins == 1) - if doa_pos.size == 0: - raise ValueError("Expected DOA token (code=1) in event sequence") - doa_days = float(times_ins[int(doa_pos[0])]) - - is_disease = codes_ins >= N_TECH_TOKENS - - # Lifetime (ever) disease history for Clean Control filtering. - if np.any(is_disease): - lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype( - np.int64, copy=False - ) - lifetime_causes = np.unique(lifetime_causes) - else: - lifetime_causes = np.zeros((0,), dtype=np.int64) - - disease_pos_all = np.flatnonzero(is_disease) - disease_times_all = ( - times_ins[disease_pos_all] - if disease_pos_all.size > 0 - else np.zeros((0,), dtype=np.float64) - ) - - eps = 1e-6 - out: List[EvalRecord] = [] - for b in range(len(age_bins_days_local) - 1): - lo = float(age_bins_days_local[b]) - hi = float(age_bins_days_local[b + 1]) - - # Inclusion rule: - # 1) DOA <= bin_upper - if not (doa_days <= hi): - continue - - # 2) at least one disease event within bin, and baseline must satisfy t0>=DOA. - # Random Single-Point Sampling: choose exactly one valid event *index* per (patient, age_bin). 
- if disease_pos_all.size == 0: - continue - - in_bin = ( - (disease_times_all >= lo) - & (disease_times_all < hi) - & (disease_times_all >= doa_days) - ) - cand_pos = disease_pos_all[in_bin] - if cand_pos.size == 0: - continue - - cutoff_pos = int(rng_local.choice(cand_pos)) - t0_days = float(times_ins[cutoff_pos]) - - # Future disease events strictly after t0 - future_mask = (times_ins > (t0_days + eps)) & is_disease - future_pos = np.flatnonzero(future_mask) - if future_pos.size == 0: - next_cause = None - next_dt_years = None - future_causes = np.zeros((0,), dtype=np.int64) - future_dt_years_arr = np.zeros((0,), dtype=np.float32) - else: - future_times_days = times_ins[future_pos] - future_tokens = codes_ins[future_pos] - future_causes = ( - future_tokens - N_TECH_TOKENS).astype(np.int64) - future_dt_years_arr = ( - (future_times_days - t0_days) / DAYS_PER_YEAR - ).astype(np.float32) - - # next-event = minimal time > t0 (tie broken by earliest position) - next_idx = int(np.argmin(future_times_days)) - next_cause = int(future_causes[next_idx]) - next_dt_years = float(future_dt_years_arr[next_idx]) - - out.append( - EvalRecord( - subset_idx=int(subset_idx), - doa_days=float(doa_days), - t0_days=float(t0_days), - cutoff_pos=int(cutoff_pos), - next_event_cause=next_cause, - next_event_dt_years=next_dt_years, - lifetime_causes=lifetime_causes, - future_causes=future_causes, - future_dt_years=future_dt_years_arr, - ) - ) - return out - - def _process_chunk( - chunk_indices: Sequence[int], - *, - age_bins_days_local: Sequence[float], - seed_local: int, - ) -> List[EvalRecord]: - out: List[EvalRecord] = [] - for subset_idx in chunk_indices: - # Ensure each subject has its own deterministic RNG stream, so parallel - # workers do not share identical seeds. - ss = np.random.SeedSequence([int(seed_local), int(subset_idx)]) - rng_local = np.random.default_rng(ss) - out.extend( - _build_records_for_index( - int(subset_idx), - age_bins_days_local=age_bins_days_local, - rng_local=rng_local, - ) - ) - return out - - n = int(len(subset)) - chunks = _iter_chunks(n, int(chunk_size)) - - do_parallel = ( - int(n_jobs) != 1 - and Parallel is not None - and delayed is not None - and n > 0 - ) - - if do_parallel: - # Note: on Windows, process-based parallelism may require the underlying - # dataset to be pickleable. `prefer="threads"` is the default for safety. 
- parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)( - delayed(_process_chunk)( - chunk, - age_bins_days_local=age_bins_days, - seed_local=int(seed), - ) - for chunk in chunks - ) - records = [r for part in parts for r in part] - return records - - # Sequential (preserve prior behavior/progress reporting) - rng = np.random.default_rng(seed) - records: List[EvalRecord] = [] - eps = 1e-6 - for subset_idx in _progress( - range(len(subset)), - enabled=show_progress, - desc="Building eval records", - total=len(subset), - ): - event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)] - codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False) - times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False) - - doa_pos = np.flatnonzero(codes_ins == 1) - if doa_pos.size == 0: - raise ValueError("Expected DOA token (code=1) in event sequence") - doa_days = float(times_ins[int(doa_pos[0])]) - - is_disease = codes_ins >= N_TECH_TOKENS - - if np.any(is_disease): - lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype( - np.int64, copy=False - ) - lifetime_causes = np.unique(lifetime_causes) - else: - lifetime_causes = np.zeros((0,), dtype=np.int64) - - disease_pos_all = np.flatnonzero(is_disease) - disease_times_all = ( - times_ins[disease_pos_all] - if disease_pos_all.size > 0 - else np.zeros((0,), dtype=np.float64) - ) - - for b in range(len(age_bins_days) - 1): - lo = age_bins_days[b] - hi = age_bins_days[b + 1] - if not (doa_days <= hi): - continue - if disease_pos_all.size == 0: - continue - - in_bin = ( - (disease_times_all >= lo) - & (disease_times_all < hi) - & (disease_times_all >= doa_days) - ) - cand_pos = disease_pos_all[in_bin] - if cand_pos.size == 0: - continue - - cutoff_pos = int(rng.choice(cand_pos)) - t0_days = float(times_ins[cutoff_pos]) - - future_mask = (times_ins > (t0_days + eps)) & is_disease - future_pos = np.flatnonzero(future_mask) - if future_pos.size == 0: - next_cause = None - next_dt_years = None - future_causes = np.zeros((0,), dtype=np.int64) - future_dt_years_arr = np.zeros((0,), dtype=np.float32) - else: - future_times_days = times_ins[future_pos] - future_tokens = codes_ins[future_pos] - future_causes = ( - future_tokens - N_TECH_TOKENS).astype(np.int64) - future_dt_years_arr = ( - (future_times_days - t0_days) / DAYS_PER_YEAR - ).astype(np.float32) - - next_idx = int(np.argmin(future_times_days)) - next_cause = int(future_causes[next_idx]) - next_dt_years = float(future_dt_years_arr[next_idx]) - - records.append( - EvalRecord( - subset_idx=int(subset_idx), - doa_days=float(doa_days), - t0_days=float(t0_days), - cutoff_pos=int(cutoff_pos), - next_event_cause=next_cause, - next_event_dt_years=next_dt_years, - lifetime_causes=lifetime_causes, - future_causes=future_causes, - future_dt_years=future_dt_years_arr, - ) - ) - - return records - - -class EvalRecordDataset(Dataset): - def __init__(self, subset: Dataset, records: Sequence[EvalRecord]): - self.subset = subset - self.records = list(records) - self._cache: Dict[int, Tuple[torch.Tensor, - torch.Tensor, torch.Tensor, torch.Tensor, int]] = {} - self._cache_order: List[int] = [] - self._cache_max = 2048 - - def __len__(self) -> int: - return len(self.records) - - def __getitem__(self, idx: int): - rec = self.records[idx] - cached = self._cache.get(rec.subset_idx) - if cached is None: - event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx] - cached = (event_seq, time_seq, cont, cate, int(sex)) - self._cache[rec.subset_idx] = cached - 
self._cache_order.append(rec.subset_idx) - if len(self._cache_order) > self._cache_max: - drop = self._cache_order.pop(0) - self._cache.pop(drop, None) - else: - event_seq, time_seq, cont, cate, sex = cached - cutoff = rec.cutoff_pos + 1 - event_seq = event_seq[:cutoff] - time_seq = time_seq[:cutoff] - baseline_pos = rec.cutoff_pos # same index in truncated sequence - return event_seq, time_seq, cont, cate, sex, baseline_pos - - -def eval_collate_fn(batch): - from torch.nn.utils.rnn import pad_sequence - - event_seqs, time_seqs, cont_feats, cate_feats, sexes, baseline_pos = zip( - *batch) - event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0) - time_batch = pad_sequence( - time_seqs, batch_first=True, padding_value=36525.0) - cont_batch = torch.stack(cont_feats, dim=0).unsqueeze(1) - cate_batch = torch.stack(cate_feats, dim=0).unsqueeze(1) - sex_batch = torch.tensor(sexes, dtype=torch.long) - baseline_pos = torch.tensor(baseline_pos, dtype=torch.long) - return event_batch, time_batch, cont_batch, cate_batch, sex_batch, baseline_pos - - -# ------------------------- -# Inference utilities -# ------------------------- - -def predict_cifs( - model: torch.nn.Module, - head: torch.nn.Module, - criterion: torch.nn.Module, - loader: DataLoader, - taus_years: Sequence[float], - device: torch.device, - show_progress: bool = False, - progress_desc: str = "Inference", -) -> np.ndarray: - model.eval() - head.eval() - - taus_t = torch.tensor(list(taus_years), dtype=torch.float32, device=device) - - all_out: List[np.ndarray] = [] - with torch.no_grad(): - for batch in _progress( - loader, - enabled=show_progress, - desc=progress_desc, - total=len(loader) if hasattr(loader, "__len__") else None, - ): - event_seq, time_seq, cont, cate, sex, baseline_pos = batch - event_seq = event_seq.to(device, non_blocking=True) - time_seq = time_seq.to(device, non_blocking=True) - cont = cont.to(device, non_blocking=True) - cate = cate.to(device, non_blocking=True) - sex = sex.to(device, non_blocking=True) - baseline_pos = baseline_pos.to(device, non_blocking=True) - - h = model(event_seq, time_seq, sex, cont, cate) - b_idx = torch.arange(h.size(0), device=device) - c = h[b_idx, baseline_pos] - logits = head(c) - - cifs = criterion.calculate_cifs(logits, taus_t) - out = cifs.detach().cpu().numpy() - all_out.append(out) - - return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,)) - - -def flatten_future_events( - records: Sequence[EvalRecord], - n_causes: int, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Flatten (record_idx, cause, dt_years) across all future events. - - Used to build horizon labels via vectorized masking + scatter. - """ - rec_idx_parts: List[np.ndarray] = [] - cause_parts: List[np.ndarray] = [] - dt_parts: List[np.ndarray] = [] - - for i, r in enumerate(records): - if r.future_causes.size == 0: - continue - causes = r.future_causes - dts = r.future_dt_years - # Keep only valid cause ids. 
- m = (causes >= 0) & (causes < n_causes) - if not np.any(m): - continue - causes = causes[m].astype(np.int64, copy=False) - dts = dts[m].astype(np.float32, copy=False) - rec_idx_parts.append(np.full((causes.size,), i, dtype=np.int32)) - cause_parts.append(causes) - dt_parts.append(dts) - - if not rec_idx_parts: - return ( - np.zeros((0,), dtype=np.int32), - np.zeros((0,), dtype=np.int64), - np.zeros((0,), dtype=np.float32), - ) - - return ( - np.concatenate(rec_idx_parts, axis=0), - np.concatenate(cause_parts, axis=0), - np.concatenate(dt_parts, axis=0), - ) - - -# ------------------------- -# Metrics helpers -# ------------------------- - -def roc_auc_ovr(y_true: np.ndarray, y_score: np.ndarray) -> float: - """Binary ROC AUC with tie-aware average ranks. - - Returns NaN if y_true has no positives or no negatives. - """ - y_true = np.asarray(y_true).astype(np.int32) - y_score = np.asarray(y_score).astype(np.float64) - - n_pos = int(y_true.sum()) - n = int(y_true.size) - n_neg = n - n_pos - if n_pos == 0 or n_neg == 0: - return float("nan") - - order = np.argsort(y_score, kind="mergesort") - scores_sorted = y_score[order] - y_sorted = y_true[order] - - ranks = np.empty(n, dtype=np.float64) - i = 0 - while i < n: - j = i + 1 - while j < n and scores_sorted[j] == scores_sorted[i]: - j += 1 - # average rank for ties, ranks are 1..n - avg_rank = 0.5 * (i + 1 + j) - ranks[i:j] = avg_rank - i = j - - sum_ranks_pos = float((ranks * y_sorted).sum()) - auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg) - return float(auc) - - -def topk_indices(scores: np.ndarray, k: int) -> np.ndarray: - """Return indices of top-k scores per row (descending).""" - if k <= 0: - raise ValueError("k must be positive") - n, K = scores.shape - k = min(k, K) - # argpartition gives arbitrary order within topk; sort those by score - part = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k] - part_scores = np.take_along_axis(scores, part, axis=1) - order = np.argsort(-part_scores, axis=1, kind="mergesort") - return np.take_along_axis(part, order, axis=1) - - -# ------------------------- -# Statistical evaluation (DeLong) -# ------------------------- - -def compute_midrank(x: np.ndarray) -> np.ndarray: - """Compute midranks of a 1D array (1-based ranks, tie-aware).""" - x = np.asarray(x, dtype=np.float64) - if x.ndim != 1: - raise ValueError("compute_midrank expects a 1D array") - - order = np.argsort(x, kind="mergesort") - x_sorted = x[order] - n = int(x_sorted.size) - - midranks = np.empty((n,), dtype=np.float64) - i = 0 - while i < n: - j = i - while j < n and x_sorted[j] == x_sorted[i]: - j += 1 - # ranks are 1..n; average over ties - mid = 0.5 * ((i + 1) + j) - midranks[i:j] = mid - i = j - - out = np.empty((n,), dtype=np.float64) - out[order] = midranks - return out - - -def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]: - """Fast DeLong method for AUC covariance. - - Args: - predictions_sorted_transposed: shape (n_classifiers, n_examples), where the first - label_1_count examples are positives. - label_1_count: number of positive examples. 
- Returns: - (aucs, delong_cov) - """ - preds = np.asarray(predictions_sorted_transposed, dtype=np.float64) - if preds.ndim != 2: - raise ValueError("predictions_sorted_transposed must be 2D") - - m = int(label_1_count) - n = int(preds.shape[1] - m) - if m <= 0 or n <= 0: - raise ValueError("DeLong requires at least 1 positive and 1 negative") - - k = int(preds.shape[0]) - tx = np.empty((k, m), dtype=np.float64) - ty = np.empty((k, n), dtype=np.float64) - tz = np.empty((k, m + n), dtype=np.float64) - - for r in range(k): - tx[r] = compute_midrank(preds[r, :m]) - ty[r] = compute_midrank(preds[r, m:]) - tz[r] = compute_midrank(preds[r, :]) - - aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n) - - v01 = (tz[:, :m] - tx) / float(n) - v10 = 1.0 - (tz[:, m:] - ty) / float(m) - - # np.cov expects variables in rows by default when rowvar=True. - sx = np.cov(v01, rowvar=True, bias=False) - sy = np.cov(v10, rowvar=True, bias=False) - delong_cov = sx / float(m) + sy / float(n) - return aucs, delong_cov - - -def compute_ground_truth_statistics(ground_truth: np.ndarray) -> Tuple[np.ndarray, int]: - """Return ordering that places positives first and label_1_count.""" - y = np.asarray(ground_truth, dtype=np.int32) - if y.ndim != 1: - raise ValueError("ground_truth must be 1D") - label_1_count = int(y.sum()) - order = np.argsort(-y, kind="mergesort") - return order, label_1_count - - -def get_auc_delong_var(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float]: - """Compute AUC and its DeLong variance. - - Args: - healthy_scores: scores for controls (label=0) - diseased_scores: scores for cases (label=1) - Returns: - (auc, auc_variance) - """ - h = np.asarray(healthy_scores, dtype=np.float64).reshape(-1) - d = np.asarray(diseased_scores, dtype=np.float64).reshape(-1) - n0 = int(h.size) - n1 = int(d.size) - if n0 == 0 or n1 == 0: - return float("nan"), float("nan") - - # Arrange positives first as required by fastDeLong. - scores = np.concatenate([d, h], axis=0) - gt = np.concatenate([ - np.ones((n1,), dtype=np.int32), - np.zeros((n0,), dtype=np.int32), - ]) - order, label_1_count = compute_ground_truth_statistics(gt) - preds_sorted = scores[order][None, :] - aucs, cov = fastDeLong(preds_sorted, label_1_count) - auc = float(aucs[0]) - cov = np.asarray(cov) - var = float(cov[0, 0]) if cov.ndim == 2 else float(cov) - return auc, var - - -# ------------------------- -# Next-token inference helper -# ------------------------- - -def predict_next_token_logits( - model: torch.nn.Module, - head: torch.nn.Module, - loader: DataLoader, - device: torch.device, - show_progress: bool = False, - progress_desc: str = "Inference (next-token)", - return_probs: bool = True, -) -> np.ndarray: - """Predict per-cause next-token scores at baseline positions. - - Returns: - np.ndarray of shape (N, K) where K is number of diseases (causes). - - Notes: - - For loss types with time/bin dimensions (e.g., discrete-time CIF), this uses the - *first* time/bin (index 0) and drops the complement channel when present. - - If return_probs=True, applies softmax over causes for probability-like scores. 
- """ - model.eval() - head.eval() - - all_out: List[np.ndarray] = [] - with torch.no_grad(): - for batch in _progress( - loader, - enabled=show_progress, - desc=progress_desc, - total=len(loader) if hasattr(loader, "__len__") else None, - ): - event_seq, time_seq, cont, cate, sex, baseline_pos = batch - event_seq = event_seq.to(device, non_blocking=True) - time_seq = time_seq.to(device, non_blocking=True) - cont = cont.to(device, non_blocking=True) - cate = cate.to(device, non_blocking=True) - sex = sex.to(device, non_blocking=True) - baseline_pos = baseline_pos.to(device, non_blocking=True) - - h = model(event_seq, time_seq, sex, cont, cate) - b_idx = torch.arange(h.size(0), device=device) - c = h[b_idx, baseline_pos] - logits = head(c) - - # logits can be (B, K) or (B, K, T) or (B, K+1, T) - if logits.ndim == 2: - cause_logits = logits - elif logits.ndim == 3: - # Use the first time/bin. - cause_logits = logits[..., 0] - else: - raise ValueError( - f"Unsupported logits shape for next-token inference: {tuple(logits.shape)}" - ) - - # If a complement/survival channel exists (discrete-time CIF), drop it. - if hasattr(model, "n_disease"): - n_disease = int(getattr(model, "n_disease")) - if cause_logits.size(1) == n_disease + 1: - cause_logits = cause_logits[:, :n_disease] - elif cause_logits.size(1) > n_disease: - cause_logits = cause_logits[:, :n_disease] - - if return_probs: - scores = torch.softmax(cause_logits, dim=1) - else: - scores = cause_logits - - all_out.append(scores.detach().cpu().numpy()) - - return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))