# DeepHealth/evaluate_next_event.py

import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
    from tqdm import tqdm  # noqa: F401
except Exception:  # pragma: no cover
    tqdm = None
from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_next_token_logits,
    get_auc_delong_var,
    seed_everything,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using next-token scores"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument(
        "--max_cpu_cores",
        type=int,
        default=-1,
        help="Maximum number of CPU cores to use for parallel data construction.",
    )
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def _format_age_bin_label(lo: float, hi: float) -> str:
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"


def _compute_next_event_auc_clean_control(
    *,
    scores: np.ndarray,
    records: list,
) -> pd.DataFrame:
    """Delphi-2M next-event AUC (clean control) per cause.

    Definitions per cause k:
      - Case: next_event_cause == k
      - Control (clean): next_event_cause != k AND k not in
        record.lifetime_causes

    AUC is computed with DeLong variance.
    """
    n_records = int(len(records))
    if n_records == 0:
        return pd.DataFrame(
            columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
        )
    K = int(scores.shape[1])
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
         for r in records],
        dtype=np.int64,
    )
    # Pre-compute the lifetime disease membership matrix for vectorized
    # clean-control filtering: lifetime_matrix[i, c] == True iff cause c is
    # present in records[i].lifetime_causes. Use a sparse matrix when SciPy
    # is available to keep memory bounded for large K.
    row_parts: List[np.ndarray] = []
    col_parts: List[np.ndarray] = []
    for i, r in enumerate(records):
        causes = getattr(r, "lifetime_causes", None)
        if causes is None:
            continue
        causes = np.asarray(causes, dtype=np.int64)
        if causes.size == 0:
            continue
        # Keep only valid cause ids.
        m_valid = (causes >= 0) & (causes < K)
        if not np.any(m_valid):
            continue
        causes = causes[m_valid]
        row_parts.append(np.full((causes.size,), i, dtype=np.int32))
        col_parts.append(causes.astype(np.int32, copy=False))
    try:
        import scipy.sparse as sp  # type: ignore

        if row_parts:
            rows = np.concatenate(row_parts, axis=0)
            cols = np.concatenate(col_parts, axis=0)
            data = np.ones((rows.size,), dtype=bool)
            lifetime_matrix = sp.csc_matrix(
                (data, (rows, cols)), shape=(n_records, K))
        else:
            lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
        lifetime_is_sparse = True
    except Exception:  # pragma: no cover
        # Dense boolean fallback when SciPy is unavailable.
        lifetime_matrix = np.zeros((n_records, K), dtype=bool)
        for rows, cols in zip(row_parts, col_parts):
            lifetime_matrix[rows.astype(np.int64, copy=False),
                            cols.astype(np.int64, copy=False)] = True
        lifetime_is_sparse = False
    auc = np.full((K,), np.nan, dtype=np.float64)
    var = np.full((K,), np.nan, dtype=np.float64)
    n_case = np.zeros((K,), dtype=np.int64)
    n_control = np.zeros((K,), dtype=np.int64)
    for k in range(K):
        case_mask = y_next == k
        if not np.any(case_mask):
            continue
        # Clean controls: not next-event k AND never had k in their lifetime
        # history.
        control_mask = y_next != k
        if np.any(control_mask):
            if lifetime_is_sparse:
                had_k = np.asarray(
                    lifetime_matrix.getcol(k).toarray().ravel(), dtype=bool)
            else:
                had_k = lifetime_matrix[:, k]
            control_mask = control_mask & ~had_k
        cs = scores[case_mask, k]
        hs = scores[control_mask, k]
        n_case[k] = int(cs.size)
        n_control[k] = int(hs.size)
        if cs.size == 0 or hs.size == 0:
            continue
        a, v = get_auc_delong_var(hs, cs)
        auc[k] = float(a)
        var[k] = float(v)
    return pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_case": n_case,
            "n_control": n_control,
            "auc": auc,
            "auc_variance": var,
        }
    )


def main() -> None:
    args = parse_args()
    # Best-effort control of implicit parallelism to avoid CPU
    # oversubscription. Note: these environment variables are ideally set
    # before importing NumPy/PyTorch, but setting them early in main can
    # still affect subprocesses and lazily initialized libraries.
    if int(args.max_cpu_cores) > 0:
        n_threads = int(args.max_cpu_cores)
        torch.set_num_threads(n_threads)
        for k in (
            "OMP_NUM_THREADS",
            "MKL_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ):
            os.environ[k] = str(n_threads)
        print(f"Restricting implicit parallelism to {n_threads} threads.")
    seed_everything(args.seed)
    show_progress = not args.no_tqdm
    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)
    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
        n_jobs=int(args.max_cpu_cores),
    )
    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)
    rec_ds = EvalRecordDataset(test_subset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )
    scores = predict_next_token_logits(
        model,
        head,
        loader,
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-token)",
        return_probs=True,
    )
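    # With return_probs=True, `scores` is expected to be an (n_records, K)
    # array of next-event probabilities, one column per cause id, despite the
    # "logits" in the helper's name.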
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
         for r in records],
        dtype=np.int64,
    )
    # Strict protocol: evaluate independently per age bin (no mixing across
    # bins); outputs keep the existing per-bin file names and shapes.
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    # Bin assignment from t0 (constructed within the bin): side="right"
    # implements the half-open interval [b_i, b_{i+1}).
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1
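    # Example (illustrative): with boundaries 40/45/50 years, t0 = 47 years
    # yields searchsorted(...) == 2 and bin_idx == 1, i.e. the [45, 50) bin;
    # t0 below the first boundary yields bin_idx == -1 and matches no bin in
    # the loop below.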
    per_bin_metric_rows: List[dict] = []
    per_bin_auc_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)
        m = bin_idx == b
        m_scores = scores[m]
        m_records = [r for r, keep in zip(records, m) if bool(keep)]
        # Coverage metric for transparency (not part of the Delphi-2M AUC
        # itself).
        m_y = y_next[m]
        n_total = int(m_y.size)
        n_eligible = int((m_y >= 0).sum())
        coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "n_records_total", "value": n_total})
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "n_next_event_eligible",
             "value": n_eligible})
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "coverage", "value": coverage})
        m_auc = _compute_next_event_auc_clean_control(
            scores=m_scores,
            records=m_records,
        )
        m_auc.insert(0, "age_bin", label)
        per_bin_auc_parts.append(m_auc)
    out_metrics_bins = os.path.join(
        run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
    out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
    if per_bin_auc_parts:
        pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
            out_auc_bins, index=False)
    else:
        pd.DataFrame(
            columns=["age_bin", "cause_id", "n_case", "n_control",
                     "auc", "auc_variance"],
        ).to_csv(out_auc_bins, index=False)
    print("PRIMARY METRICS: Per-cause AUC is reported per age bin using "
          "Delphi-2M clean controls.")
    print("EVAL METHOD: Per-cause AUCs are reported with DeLong variance.")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_auc_bins}")


if __name__ == "__main__":
    main()
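

# Example invocation (paths illustrative; flags as defined in parse_args):
#   python evaluate_next_event.py --run_dir runs/my_experiment \
#       --age_bins 40 45 50 55 60 65 70 inf --batch_size 512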