Remove evaluate_next_event.py and utils.py files to streamline the codebase. These files contained functions and classes related to evaluation and utility operations that are no longer needed.

This commit is contained in:
2026-01-18 15:41:07 +08:00
parent 1d682c130d
commit b80d9a4256
3 changed files with 0 additions and 1691 deletions

View File

@@ -1,455 +0,0 @@
"""Horizon-capture evaluation (event-driven, age-stratified).
This script implements the protocol described in 评估方案.md:
- Age-stratified evaluation: metrics are computed independently within each age bin (no mixing).
- Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and
there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin
disease events with t0 >= DOA.
- No follow-up completeness filtering (no t0+tau <= t_end constraint).
Primary outputs per age bin:
- Top-K Event Capture@tau (event-count based)
- WorkloadYield curves (Top-p% people by a person-level horizon score)
Secondary (diagnostic-only) outputs per age bin:
- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment)
"""
import argparse
import math
import os
from typing import Dict, List, Sequence, Tuple
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
from tqdm import tqdm
except Exception: # pragma: no cover
tqdm = None
from utils import (
EvalRecordDataset,
build_dataset_from_config,
build_event_driven_records,
build_model_head_criterion,
eval_collate_fn,
flatten_future_events,
get_test_subset,
load_checkpoint_into,
load_train_config,
make_inference_dataloader_kwargs,
parse_float_list,
predict_cifs,
roc_auc_ovr,
seed_everything,
topk_indices,
DAYS_PER_YEAR,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Evaluate horizon-capture using CIF at horizons")
p.add_argument("--run_dir", type=str, required=True)
p.add_argument(
"--horizons",
type=str,
nargs="+",
default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
help="Horizon grid in years",
)
p.add_argument(
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
)
p.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else "cpu"),
)
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--num_workers", type=int, default=0)
p.add_argument(
"--max_cpu_cores",
type=int,
default=-1,
help="Maximum number of CPU cores to use for parallel data construction.",
)
p.add_argument("--seed", type=int, default=0)
p.add_argument("--min_pos", type=int, default=20)
p.add_argument(
"--topk_list",
type=int,
nargs="+",
default=[5, 10, 20, 50],
)
p.add_argument(
"--workload_fracs",
type=float,
nargs="+",
default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5],
help="Fractions for workloadyield curves (Top-p%% people).",
)
p.add_argument(
"--no_tqdm",
action="store_true",
help="Disable tqdm progress bars",
)
return p.parse_args()
def _format_age_bin_label(lo: float, hi: float) -> str:
if np.isinf(hi):
return f"[{lo}, inf)"
return f"[{lo}, {hi})"
def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray:
age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64)
age_bins_days = age_bins_years * DAYS_PER_YEAR
return np.searchsorted(age_bins_days, t0_days, side="left") - 1
def _event_counts_within_tau(
n_records: int,
event_record_idx: np.ndarray,
event_dt_years: np.ndarray,
tau_years: float,
) -> np.ndarray:
"""Count events within (t0, t0+tau] per record (event-count, not unique causes)."""
if event_record_idx.size == 0:
return np.zeros((n_records,), dtype=np.int64)
m = event_dt_years <= float(tau_years)
if not np.any(m):
return np.zeros((n_records,), dtype=np.int64)
return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64)
def build_labels_within_tau_flat(
n_records: int,
n_causes: int,
event_record_idx: np.ndarray,
event_cause: np.ndarray,
event_dt_years: np.ndarray,
tau_years: float,
) -> np.ndarray:
"""Build y_within_tau using a flattened (record,cause,dt) representation.
This preserves the exact label definition: y[i,k]=1 iff at least one event of cause k
occurs in (t0, t0+tau].
"""
y = np.zeros((n_records, n_causes), dtype=np.int8)
if event_dt_years.size == 0:
return y
m = event_dt_years <= float(tau_years)
if not np.any(m):
return y
y[event_record_idx[m], event_cause[m]] = 1
return y
def main() -> None:
args = parse_args()
seed_everything(args.seed)
show_progress = (not args.no_tqdm)
run_dir = args.run_dir
cfg = load_train_config(run_dir)
dataset = build_dataset_from_config(cfg)
test_subset = get_test_subset(dataset, cfg)
age_bins_years = parse_float_list(args.age_bins)
horizons = parse_float_list(args.horizons)
horizons = [float(h) for h in horizons]
records = build_event_driven_records(
subset=test_subset,
age_bins_years=age_bins_years,
seed=args.seed,
show_progress=show_progress,
n_jobs=int(args.max_cpu_cores),
)
device = torch.device(args.device)
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
load_checkpoint_into(run_dir, model, head, criterion, device)
rec_ds = EvalRecordDataset(test_subset, records)
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
loader = DataLoader(
rec_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=eval_collate_fn,
**dl_kwargs,
)
# Print disclaimers every run (requested)
print("PRIMARY METRICS: event-count Capture@K and WorkloadYield, computed independently per age bin.")
print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).")
scores = predict_cifs(
model,
head,
criterion,
loader,
horizons,
device=device,
show_progress=show_progress,
progress_desc="Inference (horizons)",
)
# scores shape: (N, K, H)
if scores.ndim != 3:
raise ValueError(
f"Expected CIF scores with shape (N,K,H), got {scores.shape}")
N, K, H = scores.shape
if N != len(records):
raise ValueError("Record count mismatch")
# Pre-flatten all future events once to avoid repeated per-record scans.
# NOTE: these are event-level arrays (not unique causes), suitable for event-count Capture@K.
evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)
# Assign each record to an age bin (based on t0; by construction t0 is within the bin).
t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
bin_idx = _assign_age_bin_idx(t0_days, age_bins_years)
age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64)
capture_rows: List[Dict[str, object]] = []
workload_rows: List[Dict[str, object]] = []
# Diagnostics (optional): approximate event-driven AUC/Brier computed per bin.
diag_rows: List[Dict[str, object]] = []
diag_per_cause_parts: List[pd.DataFrame] = []
bins_iter = range(len(age_bins_years_arr) - 1)
if show_progress and tqdm is not None:
bins_iter = tqdm(bins_iter, total=len(
age_bins_years_arr) - 1, desc="Age bins")
for b in bins_iter:
lo = float(age_bins_years_arr[b])
hi = float(age_bins_years_arr[b + 1])
age_label = _format_age_bin_label(lo, hi)
m_rec = bin_idx == b
n_bin = int(m_rec.sum())
if n_bin == 0:
continue
rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32)
# Filter events to this bin's records once.
m_evt_bin = m_rec[evt_rec_idx] if evt_rec_idx.size > 0 else np.zeros(
(0,), dtype=bool)
evt_rec_idx_b = evt_rec_idx[m_evt_bin]
evt_cause_b = evt_cause[m_evt_bin]
evt_dt_b = evt_dt[m_evt_bin]
horizon_iter = enumerate(horizons)
if show_progress and tqdm is not None:
horizon_iter = tqdm(horizon_iter, total=len(
horizons), desc=f"Horizons {age_label}")
# Precompute a local index mapping for diagnostics label building.
local_map = np.full((N,), -1, dtype=np.int32)
local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32)
for h_idx, tau in horizon_iter:
s_tau_all = scores[:, :, h_idx]
s_tau = s_tau_all[m_rec]
# -------------------------
# Primary metric: Top-K Event Capture@tau (event-count based)
# -------------------------
denom_events = int(np.sum(evt_dt_b <= float(tau))
) if evt_dt_b.size > 0 else 0
if denom_events == 0:
for topk in args.topk_list:
capture_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"topk": int(topk),
"capture_at_k": float("nan"),
"denom_events": int(0),
"numer_events": int(0),
"n_records": int(n_bin),
"n_causes": int(K),
}
)
else:
m_evt_tau = evt_dt_b <= float(tau)
evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau]
evt_cause_tau = evt_cause_b[m_evt_tau]
# For each K, compute whether each event's cause is in that record's Top-K list.
for topk in args.topk_list:
topk = int(topk)
idx = topk_indices(s_tau_all, topk) # shape (N, topk)
idx_for_events = idx[evt_rec_idx_tau]
hits = (idx_for_events ==
evt_cause_tau[:, None]).any(axis=1)
numer_events = int(hits.sum())
capture = float(numer_events / denom_events)
capture_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"topk": int(topk),
"capture_at_k": capture,
"denom_events": int(denom_events),
"numer_events": int(numer_events),
"n_records": int(n_bin),
"n_causes": int(K),
}
)
# -------------------------
# Primary metric: WorkloadYield (Top-p% people)
# -------------------------
# Person-level score: max_k CIF_k(tau). This is used only for workloadyield ranking.
person_score = s_tau.max(axis=1) if K > 0 else np.zeros(
(n_bin,), dtype=np.float64)
order = np.argsort(-person_score, kind="mergesort")
counts_per_record = _event_counts_within_tau(
n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau)
total_events = int(counts_per_record.sum())
overall_events_per_person = (
total_events / float(n_bin)) if n_bin > 0 else float("nan")
for frac in args.workload_fracs:
frac = float(frac)
if frac <= 0.0:
continue
n_sel = int(math.ceil(frac * n_bin))
n_sel = min(max(n_sel, 1), n_bin)
sel_local = order[:n_sel]
events_captured = int(counts_per_record[sel_local].sum())
capture_rate = float(
events_captured / total_events) if total_events > 0 else float("nan")
selected_events_per_person = (
events_captured / float(n_sel)) if n_sel > 0 else float("nan")
lift = (selected_events_per_person /
overall_events_per_person) if overall_events_per_person > 0 else float("nan")
workload_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"frac_selected": float(frac),
"n_selected": int(n_sel),
"n_records": int(n_bin),
"total_events": int(total_events),
"events_captured": int(events_captured),
"capture_rate": capture_rate,
"lift_events_per_person": float(lift),
"person_score_def": "max_k_CIF_k(tau)",
}
)
# -------------------------
# Diagnostics (optional): approximate event-driven AUC/Brier
# -------------------------
# Convert event-level data to binary labels y[i,k]=1 iff >=1 event of cause k within tau.
y_tau_bin = np.zeros((n_bin, K), dtype=np.int8)
if evt_dt_b.size > 0:
m_evt_tau = evt_dt_b <= float(tau)
if np.any(m_evt_tau):
rec_local = local_map[evt_rec_idx_b[m_evt_tau]]
valid = rec_local >= 0
y_tau_bin[rec_local[valid],
evt_cause_b[m_evt_tau][valid]] = 1
n_pos = y_tau_bin.sum(axis=0).astype(np.int64)
n_neg = (int(n_bin) - n_pos).astype(np.int64)
brier_per_cause = np.mean(
(y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0
)
brier_macro = float(np.mean(brier_per_cause)
) if K > 0 else float("nan")
brier_weighted = float(np.sum(
brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
auc = np.full((K,), np.nan, dtype=np.float64)
min_pos = int(args.min_pos)
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
for k in candidates:
auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype(
np.int32), s_tau[:, k].astype(np.float64))
finite_auc = auc[np.isfinite(auc)]
auc_macro = float(np.mean(finite_auc)
) if finite_auc.size > 0 else float("nan")
w_mask = np.isfinite(auc)
auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
n_valid_auc = int(np.isfinite(auc).sum())
diag_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"n_records": int(n_bin),
"n_causes": int(K),
"auc_macro": auc_macro,
"auc_weighted_by_npos": auc_weighted,
"n_causes_valid_auc": int(n_valid_auc),
"brier_macro": brier_macro,
"brier_weighted_by_npos": brier_weighted,
}
)
diag_per_cause_parts.append(
pd.DataFrame(
{
"age_bin": age_label,
"tau_years": float(tau),
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
"n_neg": n_neg,
"auc": auc,
"brier": brier_per_cause,
}
)
)
out_capture = os.path.join(run_dir, "horizon_capture.csv")
out_wy = os.path.join(run_dir, "workload_yield.csv")
out_diag = os.path.join(run_dir, "horizon_metrics.csv")
out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv")
pd.DataFrame(capture_rows).to_csv(out_capture, index=False)
pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
pd.DataFrame(diag_rows).to_csv(out_diag, index=False)
if diag_per_cause_parts:
pd.concat(diag_per_cause_parts, ignore_index=True).to_csv(
out_diag_pc, index=False)
else:
pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos",
"n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False)
print(f"Wrote {out_capture}")
print(f"Wrote {out_wy}")
print(f"Wrote {out_diag} (diagnostic-only)")
print(f"Wrote {out_diag_pc} (diagnostic-only)")
if __name__ == "__main__":
main()

View File

@@ -1,311 +0,0 @@
import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
from tqdm import tqdm # noqa: F401
except Exception: # pragma: no cover
tqdm = None
from utils import (
EvalRecordDataset,
build_dataset_from_config,
build_event_driven_records,
build_model_head_criterion,
eval_collate_fn,
get_test_subset,
make_inference_dataloader_kwargs,
load_checkpoint_into,
load_train_config,
parse_float_list,
predict_next_token_logits,
get_auc_delong_var,
seed_everything,
DAYS_PER_YEAR,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Evaluate next-event prediction using next-token scores"
)
p.add_argument("--run_dir", type=str, required=True)
p.add_argument(
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
)
p.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else "cpu"),
)
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--num_workers", type=int, default=0)
p.add_argument(
"--max_cpu_cores",
type=int,
default=-1,
help="Maximum number of CPU cores to use for parallel data construction.",
)
p.add_argument("--seed", type=int, default=0)
p.add_argument(
"--min_pos",
type=int,
default=20,
help="Minimum positives for per-cause AUC",
)
p.add_argument(
"--no_tqdm",
action="store_true",
help="Disable tqdm progress bars",
)
return p.parse_args()
def _format_age_bin_label(lo: float, hi: float) -> str:
if np.isinf(hi):
return f"[{lo}, inf)"
return f"[{lo}, {hi})"
def _compute_next_event_auc_clean_control(
*,
scores: np.ndarray,
records: list,
) -> pd.DataFrame:
"""Delphi-2M next-event AUC (clean control) per cause.
Definitions per cause k:
- Case: next_event_cause == k
- Control (clean): next_event_cause != k AND k not in record.lifetime_causes
AUC is computed with DeLong variance.
"""
n_records = int(len(records))
if n_records == 0:
return pd.DataFrame(
columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
)
K = int(scores.shape[1])
y_next = np.array(
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
for r in records],
dtype=np.int64,
)
# Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
# lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes.
# Use a sparse matrix when SciPy is available to keep memory bounded for large K.
row_parts: List[np.ndarray] = []
col_parts: List[np.ndarray] = []
for i, r in enumerate(records):
causes = getattr(r, "lifetime_causes", None)
if causes is None:
continue
causes = np.asarray(causes, dtype=np.int64)
if causes.size == 0:
continue
# Keep only valid cause ids.
m_valid = (causes >= 0) & (causes < K)
if not np.any(m_valid):
continue
causes = causes[m_valid]
row_parts.append(np.full((causes.size,), i, dtype=np.int32))
col_parts.append(causes.astype(np.int32, copy=False))
try:
import scipy.sparse as sp # type: ignore
if row_parts:
rows = np.concatenate(row_parts, axis=0)
cols = np.concatenate(col_parts, axis=0)
data = np.ones((rows.size,), dtype=bool)
lifetime_matrix = sp.csc_matrix(
(data, (rows, cols)), shape=(n_records, K))
else:
lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
lifetime_is_sparse = True
except Exception: # pragma: no cover
lifetime_matrix = np.zeros((n_records, K), dtype=bool)
for rows, cols in zip(row_parts, col_parts):
lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(
np.int64, copy=False)] = True
lifetime_is_sparse = False
auc = np.full((K,), np.nan, dtype=np.float64)
var = np.full((K,), np.nan, dtype=np.float64)
n_case = np.zeros((K,), dtype=np.int64)
n_control = np.zeros((K,), dtype=np.int64)
for k in range(K):
case_mask = y_next == k
if not np.any(case_mask):
continue
# Clean controls: not next-event k AND never had k in their lifetime history.
control_mask = y_next != k
if np.any(control_mask):
if lifetime_is_sparse:
had_k = np.asarray(lifetime_matrix.getcol(
k).toarray().ravel(), dtype=bool)
else:
had_k = lifetime_matrix[:, k]
is_clean = ~had_k
control_mask = control_mask & is_clean
cs = scores[case_mask, k]
hs = scores[control_mask, k]
n_case[k] = int(cs.size)
n_control[k] = int(hs.size)
if cs.size == 0 or hs.size == 0:
continue
a, v = get_auc_delong_var(hs, cs)
auc[k] = float(a)
var[k] = float(v)
return pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_case": n_case,
"n_control": n_control,
"auc": auc,
"auc_variance": var,
}
)
def main() -> None:
args = parse_args()
# Best-effort control of implicit parallelism to avoid CPU oversubscription.
# Note: environment variables are ideally set before importing NumPy/PyTorch,
# but setting them early in main can still affect subprocesses or lazy readers.
if int(args.max_cpu_cores) > 0:
n_threads = int(args.max_cpu_cores)
torch.set_num_threads(n_threads)
for k in (
"OMP_NUM_THREADS",
"MKL_NUM_THREADS",
"OPENBLAS_NUM_THREADS",
"VECLIB_MAXIMUM_THREADS",
"NUMEXPR_NUM_THREADS",
):
os.environ[k] = str(n_threads)
print(f"Restricting implicit parallelism to {n_threads} threads.")
seed_everything(args.seed)
show_progress = (not args.no_tqdm)
run_dir = args.run_dir
cfg = load_train_config(run_dir)
dataset = build_dataset_from_config(cfg)
test_subset = get_test_subset(dataset, cfg)
age_bins_years = parse_float_list(args.age_bins)
records = build_event_driven_records(
subset=test_subset,
age_bins_years=age_bins_years,
seed=args.seed,
show_progress=show_progress,
n_jobs=int(args.max_cpu_cores),
)
device = torch.device(args.device)
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
load_checkpoint_into(run_dir, model, head, criterion, device)
rec_ds = EvalRecordDataset(test_subset, records)
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
loader = DataLoader(
rec_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=eval_collate_fn,
**dl_kwargs,
)
scores = predict_next_token_logits(
model,
head,
loader,
device=device,
show_progress=show_progress,
progress_desc="Inference (next-token)",
return_probs=True,
)
y_next = np.array(
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
for r in records],
dtype=np.int64,
)
# Overall (preserve existing output files/shape)
# Strict protocol: evaluate independently per age bin (no mixing).
age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
age_bins_days = age_bins_years * DAYS_PER_YEAR
# Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1
per_bin_metric_rows: List[dict] = []
per_bin_auc_parts: List[pd.DataFrame] = []
for b in range(len(age_bins_years) - 1):
lo = float(age_bins_years[b])
hi = float(age_bins_years[b + 1])
label = _format_age_bin_label(lo, hi)
m = bin_idx == b
m_scores = scores[m]
m_records = [r for r, keep in zip(records, m) if bool(keep)]
# Coverage metric for transparency (not Delphi-2M AUC itself).
m_y = y_next[m]
n_total = int(m_y.size)
n_eligible = int((m_y >= 0).sum())
coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
per_bin_metric_rows.append(
{"age_bin": label, "metric": "n_records_total", "value": n_total})
per_bin_metric_rows.append(
{"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible})
per_bin_metric_rows.append(
{"age_bin": label, "metric": "coverage", "value": coverage})
m_auc = _compute_next_event_auc_clean_control(
scores=m_scores,
records=m_records,
)
m_auc.insert(0, "age_bin", label)
per_bin_auc_parts.append(m_auc)
out_metrics_bins = os.path.join(
run_dir, "next_event_metrics_by_age_bin.csv")
pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
if per_bin_auc_parts:
pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
out_auc_bins, index=False)
else:
pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control",
"auc", "auc_variance"]).to_csv(out_auc_bins, index=False)
print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.")
print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
print(f"Wrote {out_metrics_bins}")
print(f"Wrote {out_auc_bins}")
if __name__ == "__main__":
main()

925
utils.py
View File

@@ -1,925 +0,0 @@
import json
import math
import os
import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split
try:
from tqdm import tqdm as _tqdm
except Exception: # pragma: no cover
_tqdm = None
try:
from joblib import Parallel, delayed # type: ignore
except Exception: # pragma: no cover
Parallel = None
delayed = None
from dataset import HealthDataset
from losses import (
DiscreteTimeCIFNLLLoss,
ExponentialNLLLoss,
PiecewiseExponentialCIFNLLLoss,
)
from model import DelphiFork, SapDelphi, SimpleHead
DAYS_PER_YEAR = 365.25
N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2
def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None):
if enabled and _tqdm is not None:
return _tqdm(iterable, desc=desc, total=total)
return iterable
def make_inference_dataloader_kwargs(
device: torch.device,
num_workers: int,
) -> Dict[str, Any]:
"""DataLoader kwargs tuned for inference throughput.
Behavior/metrics are unchanged; this only impacts speed.
"""
use_cuda = device.type == "cuda" and torch.cuda.is_available()
kwargs: Dict[str, Any] = {
"pin_memory": bool(use_cuda),
}
if num_workers > 0:
kwargs["persistent_workers"] = True
# default prefetch is 2; set explicitly for clarity.
kwargs["prefetch_factor"] = 2
return kwargs
# -------------------------
# Config + determinism
# -------------------------
def _replace_nonstandard_json_numbers(text: str) -> str:
# Python's json.dump writes Infinity/-Infinity/NaN for non-finite floats.
# Replace bare tokens (not within quotes) with string placeholders.
def repl(match: re.Match[str]) -> str:
token = match.group(0)
if token == "-Infinity":
return '"__NINF__"'
if token == "Infinity":
return '"__INF__"'
if token == "NaN":
return '"__NAN__"'
return token
return re.sub(r'(?<![\w\."])(-Infinity|Infinity|NaN)(?![\w\."])', repl, text)
def _restore_placeholders(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: _restore_placeholders(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_restore_placeholders(v) for v in obj]
if obj == "__INF__":
return float("inf")
if obj == "__NINF__":
return float("-inf")
if obj == "__NAN__":
return float("nan")
return obj
def load_train_config(run_dir: str) -> Dict[str, Any]:
cfg_path = os.path.join(run_dir, "train_config.json")
with open(cfg_path, "r", encoding="utf-8") as f:
raw = f.read()
raw = _replace_nonstandard_json_numbers(raw)
cfg = json.loads(raw)
cfg = _restore_placeholders(cfg)
return cfg
def seed_everything(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def parse_float_list(values: Sequence[str]) -> List[float]:
out: List[float] = []
for v in values:
s = str(v).strip().lower()
if s in {"inf", "+inf", "infty", "infinity", "+infinity"}:
out.append(float("inf"))
elif s in {"-inf", "-infty", "-infinity"}:
out.append(float("-inf"))
else:
out.append(float(v))
return out
# -------------------------
# Dataset + split (match train.py)
# -------------------------
def build_dataset_from_config(cfg: Dict[str, Any]) -> HealthDataset:
data_prefix = cfg["data_prefix"]
full_cov = bool(cfg.get("full_cov", False))
if full_cov:
cov_list = None
else:
cov_list = ["bmi", "smoking", "alcohol"]
dataset = HealthDataset(
data_prefix=data_prefix,
covariate_list=cov_list,
)
return dataset
def get_test_subset(dataset: HealthDataset, cfg: Dict[str, Any]) -> Subset:
n_total = len(dataset)
train_ratio = float(cfg["train_ratio"])
val_ratio = float(cfg["val_ratio"])
seed = int(cfg["random_seed"])
n_train = int(n_total * train_ratio)
n_val = int(n_total * val_ratio)
n_test = n_total - n_train - n_val
_, _, test_subset = random_split(
dataset,
[n_train, n_val, n_test],
generator=torch.Generator().manual_seed(seed),
)
return test_subset
# -------------------------
# Model + head + criterion (match train.py)
# -------------------------
def build_model_head_criterion(
cfg: Dict[str, Any],
dataset: HealthDataset,
device: torch.device,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
loss_type = cfg["loss_type"]
if loss_type == "exponential":
criterion = ExponentialNLLLoss(lambda_reg=float(
cfg.get("lambda_reg", 0.0))).to(device)
out_dims = [dataset.n_disease]
elif loss_type == "discrete_time_cif":
bin_edges = [float(x) for x in cfg["bin_edges"]]
criterion = DiscreteTimeCIFNLLLoss(
bin_edges=bin_edges,
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
).to(device)
out_dims = [dataset.n_disease + 1, len(bin_edges)]
elif loss_type == "pwe_cif":
# training drops +inf for PWE
raw_edges = [float(x) for x in cfg["bin_edges"]]
pwe_edges = [float(x) for x in raw_edges if math.isfinite(float(x))]
if len(pwe_edges) < 2:
raise ValueError(
"pwe_cif requires at least 2 finite bin edges (including 0). "
f"Got bin_edges={raw_edges}"
)
if float(pwe_edges[0]) != 0.0:
raise ValueError(
f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}")
criterion = PiecewiseExponentialCIFNLLLoss(
bin_edges=pwe_edges,
lambda_reg=float(cfg.get("lambda_reg", 0.0)),
).to(device)
n_bins = len(pwe_edges) - 1
out_dims = [dataset.n_disease, n_bins]
else:
raise ValueError(f"Unsupported loss_type: {loss_type}")
model_type = cfg["model_type"]
if model_type == "delphi_fork":
model = DelphiFork(
n_disease=dataset.n_disease,
n_tech_tokens=N_TECH_TOKENS,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
).to(device)
elif model_type == "sap_delphi":
model = SapDelphi(
n_disease=dataset.n_disease,
n_tech_tokens=N_TECH_TOKENS,
n_embd=int(cfg["n_embd"]),
n_head=int(cfg["n_head"]),
n_layer=int(cfg["n_layer"]),
pdrop=float(cfg.get("pdrop", 0.0)),
age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
n_cont=int(dataset.n_cont),
n_cate=int(dataset.n_cate),
cate_dims=list(dataset.cate_dims),
pretrained_weights_path=str(
cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
freeze_embeddings=True,
).to(device)
else:
raise ValueError(f"Unsupported model_type: {model_type}")
head = SimpleHead(
n_embd=int(cfg["n_embd"]),
out_dims=list(out_dims),
).to(device)
return model, head, criterion
def load_checkpoint_into(
run_dir: str,
model: torch.nn.Module,
head: torch.nn.Module,
criterion: Optional[torch.nn.Module],
device: torch.device,
) -> Dict[str, Any]:
ckpt_path = os.path.join(run_dir, "best_model.pt")
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt["model_state_dict"], strict=True)
head.load_state_dict(ckpt["head_state_dict"], strict=True)
if criterion is not None and "criterion_state_dict" in ckpt:
try:
criterion.load_state_dict(
ckpt["criterion_state_dict"], strict=False)
except Exception:
# Criterion state is not essential for inference.
pass
return ckpt
# -------------------------
# Evaluation record construction (event-driven)
# -------------------------
@dataclass(frozen=True)
class EvalRecord:
subset_idx: int
doa_days: float
t0_days: float
cutoff_pos: int # baseline position (inclusive)
next_event_cause: Optional[int]
next_event_dt_years: Optional[float]
# (U,) unique causes ever observed (clean-control filtering)
lifetime_causes: np.ndarray
future_causes: np.ndarray # (E,) in [0..K-1]
future_dt_years: np.ndarray # (E,) strictly > 0
def _to_days(x_years: float) -> float:
if math.isinf(float(x_years)):
return float("inf")
return float(x_years) * DAYS_PER_YEAR
def build_event_driven_records(
subset: Subset,
age_bins_years: Sequence[float],
seed: int,
show_progress: bool = False,
n_jobs: int = 1,
chunk_size: int = 256,
prefer: str = "threads",
) -> List[EvalRecord]:
if len(age_bins_years) < 2:
raise ValueError("age_bins must have at least 2 boundaries")
age_bins_days = [_to_days(b) for b in age_bins_years]
if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
raise ValueError("age_bins must be strictly increasing")
def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
if size <= 0:
raise ValueError("chunk_size must be >= 1")
if n == 0:
return []
idx = np.arange(n, dtype=np.int64)
return [idx[i:i + size] for i in range(0, n, size)]
def _build_records_for_index(
subset_idx: int,
*,
age_bins_days_local: Sequence[float],
rng_local: np.random.Generator,
) -> List[EvalRecord]:
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
doa_pos = np.flatnonzero(codes_ins == 1)
if doa_pos.size == 0:
raise ValueError("Expected DOA token (code=1) in event sequence")
doa_days = float(times_ins[int(doa_pos[0])])
is_disease = codes_ins >= N_TECH_TOKENS
# Lifetime (ever) disease history for Clean Control filtering.
if np.any(is_disease):
lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
np.int64, copy=False
)
lifetime_causes = np.unique(lifetime_causes)
else:
lifetime_causes = np.zeros((0,), dtype=np.int64)
disease_pos_all = np.flatnonzero(is_disease)
disease_times_all = (
times_ins[disease_pos_all]
if disease_pos_all.size > 0
else np.zeros((0,), dtype=np.float64)
)
eps = 1e-6
out: List[EvalRecord] = []
for b in range(len(age_bins_days_local) - 1):
lo = float(age_bins_days_local[b])
hi = float(age_bins_days_local[b + 1])
# Inclusion rule:
# 1) DOA <= bin_upper
if not (doa_days <= hi):
continue
# 2) at least one disease event within bin, and baseline must satisfy t0>=DOA.
# Random Single-Point Sampling: choose exactly one valid event *index* per (patient, age_bin).
if disease_pos_all.size == 0:
continue
in_bin = (
(disease_times_all >= lo)
& (disease_times_all < hi)
& (disease_times_all >= doa_days)
)
cand_pos = disease_pos_all[in_bin]
if cand_pos.size == 0:
continue
cutoff_pos = int(rng_local.choice(cand_pos))
t0_days = float(times_ins[cutoff_pos])
# Future disease events strictly after t0
future_mask = (times_ins > (t0_days + eps)) & is_disease
future_pos = np.flatnonzero(future_mask)
if future_pos.size == 0:
next_cause = None
next_dt_years = None
future_causes = np.zeros((0,), dtype=np.int64)
future_dt_years_arr = np.zeros((0,), dtype=np.float32)
else:
future_times_days = times_ins[future_pos]
future_tokens = codes_ins[future_pos]
future_causes = (
future_tokens - N_TECH_TOKENS).astype(np.int64)
future_dt_years_arr = (
(future_times_days - t0_days) / DAYS_PER_YEAR
).astype(np.float32)
# next-event = minimal time > t0 (tie broken by earliest position)
next_idx = int(np.argmin(future_times_days))
next_cause = int(future_causes[next_idx])
next_dt_years = float(future_dt_years_arr[next_idx])
out.append(
EvalRecord(
subset_idx=int(subset_idx),
doa_days=float(doa_days),
t0_days=float(t0_days),
cutoff_pos=int(cutoff_pos),
next_event_cause=next_cause,
next_event_dt_years=next_dt_years,
lifetime_causes=lifetime_causes,
future_causes=future_causes,
future_dt_years=future_dt_years_arr,
)
)
return out
def _process_chunk(
chunk_indices: Sequence[int],
*,
age_bins_days_local: Sequence[float],
seed_local: int,
) -> List[EvalRecord]:
out: List[EvalRecord] = []
for subset_idx in chunk_indices:
# Ensure each subject has its own deterministic RNG stream, so parallel
# workers do not share identical seeds.
ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
rng_local = np.random.default_rng(ss)
out.extend(
_build_records_for_index(
int(subset_idx),
age_bins_days_local=age_bins_days_local,
rng_local=rng_local,
)
)
return out
n = int(len(subset))
chunks = _iter_chunks(n, int(chunk_size))
do_parallel = (
int(n_jobs) != 1
and Parallel is not None
and delayed is not None
and n > 0
)
if do_parallel:
# Note: on Windows, process-based parallelism may require the underlying
# dataset to be pickleable. `prefer="threads"` is the default for safety.
parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
delayed(_process_chunk)(
chunk,
age_bins_days_local=age_bins_days,
seed_local=int(seed),
)
for chunk in chunks
)
records = [r for part in parts for r in part]
return records
# Sequential (preserve prior behavior/progress reporting)
rng = np.random.default_rng(seed)
records: List[EvalRecord] = []
eps = 1e-6
for subset_idx in _progress(
range(len(subset)),
enabled=show_progress,
desc="Building eval records",
total=len(subset),
):
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
doa_pos = np.flatnonzero(codes_ins == 1)
if doa_pos.size == 0:
raise ValueError("Expected DOA token (code=1) in event sequence")
doa_days = float(times_ins[int(doa_pos[0])])
is_disease = codes_ins >= N_TECH_TOKENS
if np.any(is_disease):
lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
np.int64, copy=False
)
lifetime_causes = np.unique(lifetime_causes)
else:
lifetime_causes = np.zeros((0,), dtype=np.int64)
disease_pos_all = np.flatnonzero(is_disease)
disease_times_all = (
times_ins[disease_pos_all]
if disease_pos_all.size > 0
else np.zeros((0,), dtype=np.float64)
)
for b in range(len(age_bins_days) - 1):
lo = age_bins_days[b]
hi = age_bins_days[b + 1]
if not (doa_days <= hi):
continue
if disease_pos_all.size == 0:
continue
in_bin = (
(disease_times_all >= lo)
& (disease_times_all < hi)
& (disease_times_all >= doa_days)
)
cand_pos = disease_pos_all[in_bin]
if cand_pos.size == 0:
continue
cutoff_pos = int(rng.choice(cand_pos))
t0_days = float(times_ins[cutoff_pos])
future_mask = (times_ins > (t0_days + eps)) & is_disease
future_pos = np.flatnonzero(future_mask)
if future_pos.size == 0:
next_cause = None
next_dt_years = None
future_causes = np.zeros((0,), dtype=np.int64)
future_dt_years_arr = np.zeros((0,), dtype=np.float32)
else:
future_times_days = times_ins[future_pos]
future_tokens = codes_ins[future_pos]
future_causes = (
future_tokens - N_TECH_TOKENS).astype(np.int64)
future_dt_years_arr = (
(future_times_days - t0_days) / DAYS_PER_YEAR
).astype(np.float32)
next_idx = int(np.argmin(future_times_days))
next_cause = int(future_causes[next_idx])
next_dt_years = float(future_dt_years_arr[next_idx])
records.append(
EvalRecord(
subset_idx=int(subset_idx),
doa_days=float(doa_days),
t0_days=float(t0_days),
cutoff_pos=int(cutoff_pos),
next_event_cause=next_cause,
next_event_dt_years=next_dt_years,
lifetime_causes=lifetime_causes,
future_causes=future_causes,
future_dt_years=future_dt_years_arr,
)
)
return records
class EvalRecordDataset(Dataset):
def __init__(self, subset: Dataset, records: Sequence[EvalRecord]):
self.subset = subset
self.records = list(records)
self._cache: Dict[int, Tuple[torch.Tensor,
torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
self._cache_order: List[int] = []
self._cache_max = 2048
def __len__(self) -> int:
return len(self.records)
def __getitem__(self, idx: int):
rec = self.records[idx]
cached = self._cache.get(rec.subset_idx)
if cached is None:
event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx]
cached = (event_seq, time_seq, cont, cate, int(sex))
self._cache[rec.subset_idx] = cached
self._cache_order.append(rec.subset_idx)
if len(self._cache_order) > self._cache_max:
drop = self._cache_order.pop(0)
self._cache.pop(drop, None)
else:
event_seq, time_seq, cont, cate, sex = cached
cutoff = rec.cutoff_pos + 1
event_seq = event_seq[:cutoff]
time_seq = time_seq[:cutoff]
baseline_pos = rec.cutoff_pos # same index in truncated sequence
return event_seq, time_seq, cont, cate, sex, baseline_pos
def eval_collate_fn(batch):
from torch.nn.utils.rnn import pad_sequence
event_seqs, time_seqs, cont_feats, cate_feats, sexes, baseline_pos = zip(
*batch)
event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0)
time_batch = pad_sequence(
time_seqs, batch_first=True, padding_value=36525.0)
cont_batch = torch.stack(cont_feats, dim=0).unsqueeze(1)
cate_batch = torch.stack(cate_feats, dim=0).unsqueeze(1)
sex_batch = torch.tensor(sexes, dtype=torch.long)
baseline_pos = torch.tensor(baseline_pos, dtype=torch.long)
return event_batch, time_batch, cont_batch, cate_batch, sex_batch, baseline_pos
# -------------------------
# Inference utilities
# -------------------------
def predict_cifs(
model: torch.nn.Module,
head: torch.nn.Module,
criterion: torch.nn.Module,
loader: DataLoader,
taus_years: Sequence[float],
device: torch.device,
show_progress: bool = False,
progress_desc: str = "Inference",
) -> np.ndarray:
model.eval()
head.eval()
taus_t = torch.tensor(list(taus_years), dtype=torch.float32, device=device)
all_out: List[np.ndarray] = []
with torch.no_grad():
for batch in _progress(
loader,
enabled=show_progress,
desc=progress_desc,
total=len(loader) if hasattr(loader, "__len__") else None,
):
event_seq, time_seq, cont, cate, sex, baseline_pos = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
cont = cont.to(device, non_blocking=True)
cate = cate.to(device, non_blocking=True)
sex = sex.to(device, non_blocking=True)
baseline_pos = baseline_pos.to(device, non_blocking=True)
h = model(event_seq, time_seq, sex, cont, cate)
b_idx = torch.arange(h.size(0), device=device)
c = h[b_idx, baseline_pos]
logits = head(c)
cifs = criterion.calculate_cifs(logits, taus_t)
out = cifs.detach().cpu().numpy()
all_out.append(out)
return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))
def flatten_future_events(
records: Sequence[EvalRecord],
n_causes: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Flatten (record_idx, cause, dt_years) across all future events.
Used to build horizon labels via vectorized masking + scatter.
"""
rec_idx_parts: List[np.ndarray] = []
cause_parts: List[np.ndarray] = []
dt_parts: List[np.ndarray] = []
for i, r in enumerate(records):
if r.future_causes.size == 0:
continue
causes = r.future_causes
dts = r.future_dt_years
# Keep only valid cause ids.
m = (causes >= 0) & (causes < n_causes)
if not np.any(m):
continue
causes = causes[m].astype(np.int64, copy=False)
dts = dts[m].astype(np.float32, copy=False)
rec_idx_parts.append(np.full((causes.size,), i, dtype=np.int32))
cause_parts.append(causes)
dt_parts.append(dts)
if not rec_idx_parts:
return (
np.zeros((0,), dtype=np.int32),
np.zeros((0,), dtype=np.int64),
np.zeros((0,), dtype=np.float32),
)
return (
np.concatenate(rec_idx_parts, axis=0),
np.concatenate(cause_parts, axis=0),
np.concatenate(dt_parts, axis=0),
)
# -------------------------
# Metrics helpers
# -------------------------
def roc_auc_ovr(y_true: np.ndarray, y_score: np.ndarray) -> float:
"""Binary ROC AUC with tie-aware average ranks.
Returns NaN if y_true has no positives or no negatives.
"""
y_true = np.asarray(y_true).astype(np.int32)
y_score = np.asarray(y_score).astype(np.float64)
n_pos = int(y_true.sum())
n = int(y_true.size)
n_neg = n - n_pos
if n_pos == 0 or n_neg == 0:
return float("nan")
order = np.argsort(y_score, kind="mergesort")
scores_sorted = y_score[order]
y_sorted = y_true[order]
ranks = np.empty(n, dtype=np.float64)
i = 0
while i < n:
j = i + 1
while j < n and scores_sorted[j] == scores_sorted[i]:
j += 1
# average rank for ties, ranks are 1..n
avg_rank = 0.5 * (i + 1 + j)
ranks[i:j] = avg_rank
i = j
sum_ranks_pos = float((ranks * y_sorted).sum())
auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)
return float(auc)
def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
"""Return indices of top-k scores per row (descending)."""
if k <= 0:
raise ValueError("k must be positive")
n, K = scores.shape
k = min(k, K)
# argpartition gives arbitrary order within topk; sort those by score
part = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
part_scores = np.take_along_axis(scores, part, axis=1)
order = np.argsort(-part_scores, axis=1, kind="mergesort")
return np.take_along_axis(part, order, axis=1)
# -------------------------
# Statistical evaluation (DeLong)
# -------------------------
def compute_midrank(x: np.ndarray) -> np.ndarray:
"""Compute midranks of a 1D array (1-based ranks, tie-aware)."""
x = np.asarray(x, dtype=np.float64)
if x.ndim != 1:
raise ValueError("compute_midrank expects a 1D array")
order = np.argsort(x, kind="mergesort")
x_sorted = x[order]
n = int(x_sorted.size)
midranks = np.empty((n,), dtype=np.float64)
i = 0
while i < n:
j = i
while j < n and x_sorted[j] == x_sorted[i]:
j += 1
# ranks are 1..n; average over ties
mid = 0.5 * ((i + 1) + j)
midranks[i:j] = mid
i = j
out = np.empty((n,), dtype=np.float64)
out[order] = midranks
return out
def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
"""Fast DeLong method for AUC covariance.
Args:
predictions_sorted_transposed: shape (n_classifiers, n_examples), where the first
label_1_count examples are positives.
label_1_count: number of positive examples.
Returns:
(aucs, delong_cov)
"""
preds = np.asarray(predictions_sorted_transposed, dtype=np.float64)
if preds.ndim != 2:
raise ValueError("predictions_sorted_transposed must be 2D")
m = int(label_1_count)
n = int(preds.shape[1] - m)
if m <= 0 or n <= 0:
raise ValueError("DeLong requires at least 1 positive and 1 negative")
k = int(preds.shape[0])
tx = np.empty((k, m), dtype=np.float64)
ty = np.empty((k, n), dtype=np.float64)
tz = np.empty((k, m + n), dtype=np.float64)
for r in range(k):
tx[r] = compute_midrank(preds[r, :m])
ty[r] = compute_midrank(preds[r, m:])
tz[r] = compute_midrank(preds[r, :])
aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
v01 = (tz[:, :m] - tx) / float(n)
v10 = 1.0 - (tz[:, m:] - ty) / float(m)
# np.cov expects variables in rows by default when rowvar=True.
sx = np.cov(v01, rowvar=True, bias=False)
sy = np.cov(v10, rowvar=True, bias=False)
delong_cov = sx / float(m) + sy / float(n)
return aucs, delong_cov
def compute_ground_truth_statistics(ground_truth: np.ndarray) -> Tuple[np.ndarray, int]:
"""Return ordering that places positives first and label_1_count."""
y = np.asarray(ground_truth, dtype=np.int32)
if y.ndim != 1:
raise ValueError("ground_truth must be 1D")
label_1_count = int(y.sum())
order = np.argsort(-y, kind="mergesort")
return order, label_1_count
def get_auc_delong_var(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float]:
"""Compute AUC and its DeLong variance.
Args:
healthy_scores: scores for controls (label=0)
diseased_scores: scores for cases (label=1)
Returns:
(auc, auc_variance)
"""
h = np.asarray(healthy_scores, dtype=np.float64).reshape(-1)
d = np.asarray(diseased_scores, dtype=np.float64).reshape(-1)
n0 = int(h.size)
n1 = int(d.size)
if n0 == 0 or n1 == 0:
return float("nan"), float("nan")
# Arrange positives first as required by fastDeLong.
scores = np.concatenate([d, h], axis=0)
gt = np.concatenate([
np.ones((n1,), dtype=np.int32),
np.zeros((n0,), dtype=np.int32),
])
order, label_1_count = compute_ground_truth_statistics(gt)
preds_sorted = scores[order][None, :]
aucs, cov = fastDeLong(preds_sorted, label_1_count)
auc = float(aucs[0])
cov = np.asarray(cov)
var = float(cov[0, 0]) if cov.ndim == 2 else float(cov)
return auc, var
# -------------------------
# Next-token inference helper
# -------------------------
def predict_next_token_logits(
model: torch.nn.Module,
head: torch.nn.Module,
loader: DataLoader,
device: torch.device,
show_progress: bool = False,
progress_desc: str = "Inference (next-token)",
return_probs: bool = True,
) -> np.ndarray:
"""Predict per-cause next-token scores at baseline positions.
Returns:
np.ndarray of shape (N, K) where K is number of diseases (causes).
Notes:
- For loss types with time/bin dimensions (e.g., discrete-time CIF), this uses the
*first* time/bin (index 0) and drops the complement channel when present.
- If return_probs=True, applies softmax over causes for probability-like scores.
"""
model.eval()
head.eval()
all_out: List[np.ndarray] = []
with torch.no_grad():
for batch in _progress(
loader,
enabled=show_progress,
desc=progress_desc,
total=len(loader) if hasattr(loader, "__len__") else None,
):
event_seq, time_seq, cont, cate, sex, baseline_pos = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)
cont = cont.to(device, non_blocking=True)
cate = cate.to(device, non_blocking=True)
sex = sex.to(device, non_blocking=True)
baseline_pos = baseline_pos.to(device, non_blocking=True)
h = model(event_seq, time_seq, sex, cont, cate)
b_idx = torch.arange(h.size(0), device=device)
c = h[b_idx, baseline_pos]
logits = head(c)
# logits can be (B, K) or (B, K, T) or (B, K+1, T)
if logits.ndim == 2:
cause_logits = logits
elif logits.ndim == 3:
# Use the first time/bin.
cause_logits = logits[..., 0]
else:
raise ValueError(
f"Unsupported logits shape for next-token inference: {tuple(logits.shape)}"
)
# If a complement/survival channel exists (discrete-time CIF), drop it.
if hasattr(model, "n_disease"):
n_disease = int(getattr(model, "n_disease"))
if cause_logits.size(1) == n_disease + 1:
cause_logits = cause_logits[:, :n_disease]
elif cause_logits.size(1) > n_disease:
cause_logits = cause_logits[:, :n_disease]
if return_probs:
scores = torch.softmax(cause_logits, dim=1)
else:
scores = cause_logits
all_out.append(scores.detach().cpu().numpy())
return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))