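"""Evaluate next-event prediction using next-token scores.

Writes ``next_event_metrics_by_age_bin.csv`` and
``next_event_auc_by_age_bin.csv`` into the run directory given by ``--run_dir``.
"""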
import argparse
|
|
import os
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
try:
|
|
from tqdm import tqdm # noqa: F401
|
|
except Exception: # pragma: no cover
|
|
tqdm = None
|
|
|
|
from utils import (
|
|
EvalRecordDataset,
|
|
build_dataset_from_config,
|
|
build_event_driven_records,
|
|
build_model_head_criterion,
|
|
eval_collate_fn,
|
|
get_test_subset,
|
|
make_inference_dataloader_kwargs,
|
|
load_checkpoint_into,
|
|
load_train_config,
|
|
parse_float_list,
|
|
predict_next_token_logits,
|
|
get_auc_delong_var,
|
|
seed_everything,
|
|
DAYS_PER_YEAR,
|
|
)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Declare the command-line options of the evaluation script and parse them.

    Returns:
        The namespace produced by ``argparse`` with the attributes ``run_dir``,
        ``age_bins``, ``device``, ``batch_size``, ``num_workers``,
        ``max_cpu_cores``, ``seed``, ``min_pos`` and ``no_tqdm``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate next-event prediction using next-token scores"
    )
    add = parser.add_argument

    add("--run_dir", type=str, required=True)

    # Boundaries are kept as raw strings at this stage (one or more tokens).
    add(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
    )

    # Default to the GPU only when torch reports that CUDA is usable.
    if torch.cuda.is_available():
        fallback_device = "cuda"
    else:
        fallback_device = "cpu"
    add("--device", type=str, default=fallback_device)

    # Plain integer options: (flag, default, help text or None for no help).
    int_options = (
        ("--batch_size", 256, None),
        ("--num_workers", 0, None),
        (
            "--max_cpu_cores",
            -1,
            "Maximum number of CPU cores to use for parallel data construction.",
        ),
        ("--seed", 0, None),
        ("--min_pos", 20, "Minimum positives for per-cause AUC"),
    )
    for flag, default_value, help_text in int_options:
        add(flag, type=int, default=default_value, help=help_text)

    add("--no_tqdm", action="store_true", help="Disable tqdm progress bars")

    return parser.parse_args()
|
|
|
|
|
|
def _format_age_bin_label(lo: float, hi: float) -> str:
    """Render an age bin as a half-open interval label, e.g. ``[40.0, 45.0)``.

    Any infinite upper bound is printed as the literal text ``inf``.
    """
    upper_text = "inf" if np.isinf(hi) else f"{hi}"
    return "[" + f"{lo}" + ", " + upper_text + ")"
|
|
|
|
|
|
def _compute_next_event_auc_clean_control(
    *,
    scores: np.ndarray,
    records: list,
) -> pd.DataFrame:
    """Delphi-2M next-event AUC (clean control) per cause.

    Definitions per cause k:
      - Case: next_event_cause == k
      - Control (clean): next_event_cause != k AND k not in record.lifetime_causes

    AUC and its variance are obtained from ``get_auc_delong_var``.

    Args:
        scores: 2-D array indexed as ``scores[i, k]`` (row ``i`` belongs to
            ``records[i]``, column ``k`` to cause ``k``). Must have exactly one
            row per record.
        records: Sequence of objects exposing ``next_event_cause`` (``None`` or
            an int-like cause id) and, optionally, ``lifetime_causes`` (an
            array-like of cause ids; ids outside ``[0, K)`` are ignored, a
            missing/``None``/empty value means "no history").

    Returns:
        A DataFrame with one row per cause id ``0..K-1`` (``K = scores.shape[1]``)
        and the columns ``cause_id``, ``n_case``, ``n_control``, ``auc`` and
        ``auc_variance``. Counts are only filled for causes with at least one
        case; ``auc``/``auc_variance`` stay NaN when a cause has no case or no
        clean control. When ``records`` is empty, an empty frame with the
        columns ``cause_id``, ``n_case``, ``n_control``, ``auc``,
        ``auc_variance`` is returned.

    Raises:
        ValueError: If the number of score rows differs from ``len(records)``.
    """
    n_records = len(records)
    if n_records == 0:
        return pd.DataFrame(
            columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
        )

    if int(scores.shape[0]) != n_records:
        # Fail loudly instead of silently mis-aligning scores and records.
        raise ValueError(
            f"scores has {int(scores.shape[0])} rows but {n_records} records were given"
        )

    K = int(scores.shape[1])
    y_next = np.array(
        [
            -1 if r.next_event_cause is None else int(r.next_event_cause)
            for r in records
        ],
        dtype=np.int64,
    )

    # Collect (record index, cause id) pairs for every valid lifetime cause.
    row_parts: List[np.ndarray] = []
    col_parts: List[np.ndarray] = []
    for i, r in enumerate(records):
        causes = getattr(r, "lifetime_causes", None)
        if causes is None:
            continue
        causes = np.asarray(causes, dtype=np.int64)
        if causes.size == 0:
            continue
        # Keep only valid cause ids.
        causes = causes[(causes >= 0) & (causes < K)]
        if causes.size == 0:
            continue
        row_parts.append(np.full((causes.size,), i, dtype=np.int64))
        col_parts.append(causes)

    # Group the record indices by cause id (a compressed-column style index):
    # rows_by_cause[bounds[k]:bounds[k + 1]] are the records whose lifetime
    # history contains cause k. Memory is proportional to the number of pairs,
    # so no dense (n_records, K) matrix and no optional dependency is needed.
    if row_parts:
        rows = np.concatenate(row_parts, axis=0)
        cols = np.concatenate(col_parts, axis=0)
        order = np.argsort(cols, kind="stable")
        rows_by_cause = rows[order]
        bounds = np.searchsorted(cols[order], np.arange(K + 1), side="left")
    else:
        rows_by_cause = np.empty((0,), dtype=np.int64)
        bounds = np.zeros((K + 1,), dtype=np.int64)

    auc = np.full((K,), np.nan, dtype=np.float64)
    var = np.full((K,), np.nan, dtype=np.float64)
    n_case = np.zeros((K,), dtype=np.int64)
    n_control = np.zeros((K,), dtype=np.int64)

    # Only causes that occur as a next event can have cases; all other causes
    # keep their zero counts and NaN metrics.
    observed_causes = np.unique(y_next[(y_next >= 0) & (y_next < K)])
    for k in observed_causes.tolist():
        case_mask = y_next == k

        # Clean controls: not next-event k AND never had k in their lifetime history.
        had_k = np.zeros((n_records,), dtype=bool)
        had_k[rows_by_cause[bounds[k]:bounds[k + 1]]] = True
        control_mask = ~case_mask & ~had_k

        case_scores = scores[case_mask, k]
        control_scores = scores[control_mask, k]
        n_case[k] = int(case_scores.size)
        n_control[k] = int(control_scores.size)
        if control_scores.size == 0:
            continue

        # Argument order (controls first, then cases) is kept exactly as before.
        # NOTE(review): confirm this matches get_auc_delong_var's signature.
        a, v = get_auc_delong_var(control_scores, case_scores)
        auc[k] = float(a)
        var[k] = float(v)

    return pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_case": n_case,
            "n_control": n_control,
            "auc": auc,
            "auc_variance": var,
        }
    )
|
|
|
|
|
|
def main() -> None:
    """Evaluate next-event prediction for a trained run, independently per age bin.

    Steps, in order: parse the CLI, seed, load the training config and the
    test subset, build the evaluation records, restore the checkpoint, run
    next-token inference, then for each age bin ``[lo, hi)`` compute coverage
    counts and per-cause clean-control AUC.

    Two CSV files are written into ``args.run_dir``:
      - ``next_event_metrics_by_age_bin.csv``: long-format rows
        (``age_bin``, ``metric``, ``value``) with ``n_records_total``,
        ``n_next_event_eligible`` and ``coverage`` for every bin.
      - ``next_event_auc_by_age_bin.csv``: per-bin, per-cause AUC table.

    NOTE(review): ``--min_pos`` is parsed by ``parse_args`` but is not applied
    anywhere in this function; confirm whether AUC rows with fewer positives
    are meant to be filtered.
    """
    args = parse_args()
    seed_everything(args.seed)

    show_progress = not args.no_tqdm

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)

    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
        n_jobs=int(args.max_cpu_cores),
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)

    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,  # keep score rows aligned with `records`
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    scores = predict_next_token_logits(
        model,
        head,
        loader,
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-token)",
        return_probs=True,
    )

    y_next = np.array(
        [
            -1 if r.next_event_cause is None else int(r.next_event_cause)
            for r in records
        ],
        dtype=np.int64,
    )

    # Strict protocol: evaluate independently per age bin (no mixing).
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1}).
    # side="right" makes a t0 that sits exactly on a lower boundary b_i fall
    # into bin i, matching the half-open labels; side="left" would push it
    # into the previous bin (or drop it for the first bin).
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1

    per_bin_metric_rows: List[dict] = []
    per_bin_auc_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)
        m = bin_idx == b
        m_scores = scores[m]
        m_records = [records[i] for i in np.flatnonzero(m).tolist()]

        # Coverage metric for transparency (not Delphi-2M AUC itself).
        m_y = y_next[m]
        n_total = int(m_y.size)
        n_eligible = int((m_y >= 0).sum())
        coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
        per_bin_metric_rows.extend(
            [
                {"age_bin": label, "metric": "n_records_total", "value": n_total},
                {"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible},
                {"age_bin": label, "metric": "coverage", "value": coverage},
            ]
        )

        m_auc = _compute_next_event_auc_clean_control(
            scores=m_scores,
            records=m_records,
        )
        m_auc.insert(0, "age_bin", label)
        per_bin_auc_parts.append(m_auc)

    out_metrics_bins = os.path.join(
        run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)

    out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
    # Bins without records yield zero-row frames; leaving them out of the
    # concat keeps the numeric column dtypes and contributes no rows anyway.
    non_empty_auc_parts = [part for part in per_bin_auc_parts if not part.empty]
    if non_empty_auc_parts:
        pd.concat(non_empty_auc_parts, ignore_index=True).to_csv(
            out_auc_bins, index=False)
    else:
        pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control",
                              "auc", "auc_variance"]).to_csv(out_auc_bins, index=False)

    print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.")
    print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_auc_bins}")
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|