Files
DeepHealth/evaluate_next_event.py

296 lines
9.3 KiB
Python
Raw Normal View History

import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
from tqdm import tqdm # noqa: F401
except Exception: # pragma: no cover
tqdm = None
from utils import (
EvalRecordDataset,
build_dataset_from_config,
build_event_driven_records,
build_model_head_criterion,
eval_collate_fn,
get_test_subset,
make_inference_dataloader_kwargs,
load_checkpoint_into,
load_train_config,
parse_float_list,
predict_next_token_logits,
get_auc_delong_var,
seed_everything,
DAYS_PER_YEAR,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Evaluate next-event prediction using next-token scores"
)
p.add_argument("--run_dir", type=str, required=True)
p.add_argument(
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
)
p.add_argument(
"--device",
type=str,
default=("cuda" if torch.cuda.is_available() else "cpu"),
)
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--num_workers", type=int, default=0)
p.add_argument(
"--max_cpu_cores",
type=int,
default=-1,
help="Maximum number of CPU cores to use for parallel data construction.",
)
p.add_argument("--seed", type=int, default=0)
p.add_argument(
"--min_pos",
type=int,
default=20,
help="Minimum positives for per-cause AUC",
)
p.add_argument(
"--no_tqdm",
action="store_true",
help="Disable tqdm progress bars",
)
return p.parse_args()
def _format_age_bin_label(lo: float, hi: float) -> str:
if np.isinf(hi):
return f"[{lo}, inf)"
return f"[{lo}, {hi})"
def _compute_next_event_auc_clean_control(
*,
scores: np.ndarray,
records: list,
) -> pd.DataFrame:
"""Delphi-2M next-event AUC (clean control) per cause.
Definitions per cause k:
- Case: next_event_cause == k
- Control (clean): next_event_cause != k AND k not in record.lifetime_causes
AUC is computed with DeLong variance.
"""
n_records = int(len(records))
if n_records == 0:
return pd.DataFrame(
columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
)
K = int(scores.shape[1])
y_next = np.array(
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
for r in records],
dtype=np.int64,
)
# Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
# lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes.
# Use a sparse matrix when SciPy is available to keep memory bounded for large K.
row_parts: List[np.ndarray] = []
col_parts: List[np.ndarray] = []
for i, r in enumerate(records):
causes = getattr(r, "lifetime_causes", None)
if causes is None:
continue
causes = np.asarray(causes, dtype=np.int64)
if causes.size == 0:
continue
# Keep only valid cause ids.
m_valid = (causes >= 0) & (causes < K)
if not np.any(m_valid):
continue
causes = causes[m_valid]
row_parts.append(np.full((causes.size,), i, dtype=np.int32))
col_parts.append(causes.astype(np.int32, copy=False))
try:
import scipy.sparse as sp # type: ignore
if row_parts:
rows = np.concatenate(row_parts, axis=0)
cols = np.concatenate(col_parts, axis=0)
data = np.ones((rows.size,), dtype=bool)
lifetime_matrix = sp.csc_matrix(
(data, (rows, cols)), shape=(n_records, K))
else:
lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
lifetime_is_sparse = True
except Exception: # pragma: no cover
lifetime_matrix = np.zeros((n_records, K), dtype=bool)
for rows, cols in zip(row_parts, col_parts):
lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(
np.int64, copy=False)] = True
lifetime_is_sparse = False
auc = np.full((K,), np.nan, dtype=np.float64)
var = np.full((K,), np.nan, dtype=np.float64)
n_case = np.zeros((K,), dtype=np.int64)
n_control = np.zeros((K,), dtype=np.int64)
for k in range(K):
case_mask = y_next == k
if not np.any(case_mask):
continue
# Clean controls: not next-event k AND never had k in their lifetime history.
control_mask = y_next != k
if np.any(control_mask):
if lifetime_is_sparse:
had_k = np.asarray(lifetime_matrix.getcol(
k).toarray().ravel(), dtype=bool)
else:
had_k = lifetime_matrix[:, k]
is_clean = ~had_k
control_mask = control_mask & is_clean
cs = scores[case_mask, k]
hs = scores[control_mask, k]
n_case[k] = int(cs.size)
n_control[k] = int(hs.size)
if cs.size == 0 or hs.size == 0:
continue
a, v = get_auc_delong_var(hs, cs)
auc[k] = float(a)
var[k] = float(v)
return pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_case": n_case,
"n_control": n_control,
"auc": auc,
"auc_variance": var,
}
)
def main() -> None:
args = parse_args()
seed_everything(args.seed)
show_progress = (not args.no_tqdm)
run_dir = args.run_dir
cfg = load_train_config(run_dir)
dataset = build_dataset_from_config(cfg)
test_subset = get_test_subset(dataset, cfg)
age_bins_years = parse_float_list(args.age_bins)
records = build_event_driven_records(
subset=test_subset,
age_bins_years=age_bins_years,
seed=args.seed,
show_progress=show_progress,
n_jobs=int(args.max_cpu_cores),
)
device = torch.device(args.device)
model, head, criterion = build_model_head_criterion(cfg, dataset, device)
load_checkpoint_into(run_dir, model, head, criterion, device)
rec_ds = EvalRecordDataset(test_subset, records)
dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
loader = DataLoader(
rec_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
collate_fn=eval_collate_fn,
**dl_kwargs,
)
scores = predict_next_token_logits(
model,
head,
loader,
device=device,
show_progress=show_progress,
progress_desc="Inference (next-token)",
return_probs=True,
)
y_next = np.array(
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
for r in records],
dtype=np.int64,
)
# Overall (preserve existing output files/shape)
# Strict protocol: evaluate independently per age bin (no mixing).
age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
age_bins_days = age_bins_years * DAYS_PER_YEAR
# Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1
per_bin_metric_rows: List[dict] = []
per_bin_auc_parts: List[pd.DataFrame] = []
for b in range(len(age_bins_years) - 1):
lo = float(age_bins_years[b])
hi = float(age_bins_years[b + 1])
label = _format_age_bin_label(lo, hi)
m = bin_idx == b
m_scores = scores[m]
m_records = [r for r, keep in zip(records, m) if bool(keep)]
# Coverage metric for transparency (not Delphi-2M AUC itself).
m_y = y_next[m]
n_total = int(m_y.size)
n_eligible = int((m_y >= 0).sum())
coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
per_bin_metric_rows.append(
{"age_bin": label, "metric": "n_records_total", "value": n_total})
per_bin_metric_rows.append(
{"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible})
per_bin_metric_rows.append(
{"age_bin": label, "metric": "coverage", "value": coverage})
m_auc = _compute_next_event_auc_clean_control(
scores=m_scores,
records=m_records,
)
m_auc.insert(0, "age_bin", label)
per_bin_auc_parts.append(m_auc)
out_metrics_bins = os.path.join(
run_dir, "next_event_metrics_by_age_bin.csv")
pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
if per_bin_auc_parts:
pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
out_auc_bins, index=False)
else:
pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control",
"auc", "auc_variance"]).to_csv(out_auc_bins, index=False)
print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.")
print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
print(f"Wrote {out_metrics_bins}")
print(f"Wrote {out_auc_bins}")
if __name__ == "__main__":
main()