Remove evaluate_next_event.py, utils.py, and the horizon-capture evaluation script to streamline the codebase. These files contained evaluation and shared utility functions and classes that are no longer needed.
@@ -1,455 +0,0 @@
"""Horizon-capture evaluation (event-driven, age-stratified).

This script implements the protocol described in 评估方案.md:

- Age-stratified evaluation: metrics are computed independently within each age bin (no mixing).
- Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and
  there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin
  disease events with t0 >= DOA.
- No follow-up completeness filtering (no t0+tau <= t_end constraint).

Primary outputs per age bin:
- Top-K Event Capture@tau (event-count based)
- Workload–Yield curves (Top-p% people by a person-level horizon score)

Secondary (diagnostic-only) outputs per age bin:
- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment)
"""
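
# Worked example of the primary metric (illustrative only, with made-up numbers): suppose an
# age bin holds 3 records and, within tau, record 0 has events for causes {3, 7}, record 1 has
# one event for cause 5, and record 2 has none. The denominator is the event count (3). If the
# Top-5 CIF ranking for record 0 contains causes 3 and 7 while the ranking for record 1 misses
# cause 5, then numer_events = 2 and Capture@5(tau) = 2 / 3. Workload–Yield instead ranks
# *people* by max_k CIF_k(tau) and reports the fraction of in-bin events captured by the
# top-p% of records.
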
import argparse
import math
import os
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm
except Exception:  # pragma: no cover
    tqdm = None

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    flatten_future_events,
    get_test_subset,
    load_checkpoint_into,
    load_train_config,
    make_inference_dataloader_kwargs,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate horizon-capture using CIF at horizons")
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument(
        "--horizons",
        type=str,
        nargs="+",
        default=["0.25", "0.5", "1.0", "2.0", "5.0", "10.0"],
        help="Horizon grid in years",
    )
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument(
        "--max_cpu_cores",
        type=int,
        default=-1,
        help="Maximum number of CPU cores to use for parallel data construction.",
    )
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--min_pos", type=int, default=20)
    p.add_argument(
        "--topk_list",
        type=int,
        nargs="+",
        default=[5, 10, 20, 50],
    )
    p.add_argument(
        "--workload_fracs",
        type=float,
        nargs="+",
        default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5],
        help="Fractions for workload–yield curves (Top-p%% people).",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def _format_age_bin_label(lo: float, hi: float) -> str:
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"


def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray:
    age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    return np.searchsorted(age_bins_days, t0_days, side="left") - 1
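
# Example of the bin assignment above (hypothetical boundaries, ages in years): with
# age_bins_years = [40, 45, 50], a record with t0 = 43 years maps to
#   np.searchsorted(np.array([40, 45, 50]) * DAYS_PER_YEAR, 43 * DAYS_PER_YEAR, side="left") - 1 == 0,
# i.e. bin [40, 45); a record with t0 = 47 years maps to index 1, i.e. bin [45, 50).
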
def _event_counts_within_tau(
    n_records: int,
    event_record_idx: np.ndarray,
    event_dt_years: np.ndarray,
    tau_years: float,
) -> np.ndarray:
    """Count events within (t0, t0+tau] per record (event-count, not unique causes)."""
    if event_record_idx.size == 0:
        return np.zeros((n_records,), dtype=np.int64)
    m = event_dt_years <= float(tau_years)
    if not np.any(m):
        return np.zeros((n_records,), dtype=np.int64)
    return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64)


def build_labels_within_tau_flat(
    n_records: int,
    n_causes: int,
    event_record_idx: np.ndarray,
    event_cause: np.ndarray,
    event_dt_years: np.ndarray,
    tau_years: float,
) -> np.ndarray:
    """Build y_within_tau using a flattened (record, cause, dt) representation.

    This preserves the exact label definition: y[i, k] = 1 iff at least one event of cause k
    occurs in (t0, t0+tau].
    """
    y = np.zeros((n_records, n_causes), dtype=np.int8)
    if event_dt_years.size == 0:
        return y
    m = event_dt_years <= float(tau_years)
    if not np.any(m):
        return y
    y[event_record_idx[m], event_cause[m]] = 1
    return y


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)

    show_progress = (not args.no_tqdm)

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)

    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    horizons = parse_float_list(args.horizons)
    horizons = [float(h) for h in horizons]

    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
        n_jobs=int(args.max_cpu_cores),
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)

    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    # Print disclaimers every run (requested)
    print("PRIMARY METRICS: event-count Capture@K and Workload–Yield, computed independently per age bin.")
    print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).")

    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        horizons,
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (horizons)",
    )
    # scores shape: (N, K, H)
    if scores.ndim != 3:
        raise ValueError(f"Expected CIF scores with shape (N, K, H), got {scores.shape}")

    N, K, H = scores.shape
    if N != len(records):
        raise ValueError("Record count mismatch")

    # Pre-flatten all future events once to avoid repeated per-record scans.
    # NOTE: these are event-level arrays (not unique causes), suitable for event-count Capture@K.
    evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)

    # Assign each record to an age bin (based on t0; by construction t0 is within the bin).
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = _assign_age_bin_idx(t0_days, age_bins_years)
    age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64)

    capture_rows: List[Dict[str, object]] = []
    workload_rows: List[Dict[str, object]] = []

    # Diagnostics (optional): approximate event-driven AUC/Brier computed per bin.
    diag_rows: List[Dict[str, object]] = []
    diag_per_cause_parts: List[pd.DataFrame] = []

    bins_iter = range(len(age_bins_years_arr) - 1)
    if show_progress and tqdm is not None:
        bins_iter = tqdm(bins_iter, total=len(age_bins_years_arr) - 1, desc="Age bins")

    for b in bins_iter:
        lo = float(age_bins_years_arr[b])
        hi = float(age_bins_years_arr[b + 1])
        age_label = _format_age_bin_label(lo, hi)

        m_rec = bin_idx == b
        n_bin = int(m_rec.sum())
        if n_bin == 0:
            continue

        rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32)

        # Filter events to this bin's records once.
        m_evt_bin = m_rec[evt_rec_idx] if evt_rec_idx.size > 0 else np.zeros((0,), dtype=bool)
        evt_rec_idx_b = evt_rec_idx[m_evt_bin]
        evt_cause_b = evt_cause[m_evt_bin]
        evt_dt_b = evt_dt[m_evt_bin]

        horizon_iter = enumerate(horizons)
        if show_progress and tqdm is not None:
            horizon_iter = tqdm(horizon_iter, total=len(horizons), desc=f"Horizons {age_label}")

        # Precompute a local index mapping for diagnostics label building.
        local_map = np.full((N,), -1, dtype=np.int32)
        local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32)

        for h_idx, tau in horizon_iter:
            s_tau_all = scores[:, :, h_idx]
            s_tau = s_tau_all[m_rec]

            # -------------------------
            # Primary metric: Top-K Event Capture@tau (event-count based)
            # -------------------------
            denom_events = int(np.sum(evt_dt_b <= float(tau))) if evt_dt_b.size > 0 else 0
            if denom_events == 0:
                for topk in args.topk_list:
                    capture_rows.append(
                        {
                            "age_bin": age_label,
                            "tau_years": float(tau),
                            "topk": int(topk),
                            "capture_at_k": float("nan"),
                            "denom_events": int(0),
                            "numer_events": int(0),
                            "n_records": int(n_bin),
                            "n_causes": int(K),
                        }
                    )
            else:
                m_evt_tau = evt_dt_b <= float(tau)
                evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau]
                evt_cause_tau = evt_cause_b[m_evt_tau]

                # For each K, compute whether each event's cause is in that record's Top-K list.
                for topk in args.topk_list:
                    topk = int(topk)
                    idx = topk_indices(s_tau_all, topk)  # shape (N, topk)
                    idx_for_events = idx[evt_rec_idx_tau]
                    hits = (idx_for_events == evt_cause_tau[:, None]).any(axis=1)
                    numer_events = int(hits.sum())
                    capture = float(numer_events / denom_events)
                    capture_rows.append(
                        {
                            "age_bin": age_label,
                            "tau_years": float(tau),
                            "topk": int(topk),
                            "capture_at_k": capture,
                            "denom_events": int(denom_events),
                            "numer_events": int(numer_events),
                            "n_records": int(n_bin),
                            "n_causes": int(K),
                        }
                    )

            # -------------------------
            # Primary metric: Workload–Yield (Top-p% people)
            # -------------------------
            # Person-level score: max_k CIF_k(tau). This is used only for workload–yield ranking.
            person_score = s_tau.max(axis=1) if K > 0 else np.zeros((n_bin,), dtype=np.float64)
            order = np.argsort(-person_score, kind="mergesort")

            counts_per_record = _event_counts_within_tau(
                n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau)
            total_events = int(counts_per_record.sum())
            overall_events_per_person = (total_events / float(n_bin)) if n_bin > 0 else float("nan")

            for frac in args.workload_fracs:
                frac = float(frac)
                if frac <= 0.0:
                    continue
                n_sel = int(math.ceil(frac * n_bin))
                n_sel = min(max(n_sel, 1), n_bin)
                sel_local = order[:n_sel]
                events_captured = int(counts_per_record[sel_local].sum())
                capture_rate = float(events_captured / total_events) if total_events > 0 else float("nan")

                selected_events_per_person = (events_captured / float(n_sel)) if n_sel > 0 else float("nan")
                lift = (selected_events_per_person / overall_events_per_person) if overall_events_per_person > 0 else float("nan")

                workload_rows.append(
                    {
                        "age_bin": age_label,
                        "tau_years": float(tau),
                        "frac_selected": float(frac),
                        "n_selected": int(n_sel),
                        "n_records": int(n_bin),
                        "total_events": int(total_events),
                        "events_captured": int(events_captured),
                        "capture_rate": capture_rate,
                        "lift_events_per_person": float(lift),
                        "person_score_def": "max_k_CIF_k(tau)",
                    }
                )

            # -------------------------
            # Diagnostics (optional): approximate event-driven AUC/Brier
            # -------------------------
            # Convert event-level data to binary labels y[i, k] = 1 iff >= 1 event of cause k within tau.
            y_tau_bin = np.zeros((n_bin, K), dtype=np.int8)
            if evt_dt_b.size > 0:
                m_evt_tau = evt_dt_b <= float(tau)
                if np.any(m_evt_tau):
                    rec_local = local_map[evt_rec_idx_b[m_evt_tau]]
                    valid = rec_local >= 0
                    y_tau_bin[rec_local[valid], evt_cause_b[m_evt_tau][valid]] = 1

            n_pos = y_tau_bin.sum(axis=0).astype(np.int64)
            n_neg = (int(n_bin) - n_pos).astype(np.int64)

            brier_per_cause = np.mean(
                (y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0
            )
            brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
            brier_weighted = float(np.sum(brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")

            auc = np.full((K,), np.nan, dtype=np.float64)
            min_pos = int(args.min_pos)
            candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
            for k in candidates:
                auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype(np.int32), s_tau[:, k].astype(np.float64))

            finite_auc = auc[np.isfinite(auc)]
            auc_macro = float(np.mean(finite_auc)) if finite_auc.size > 0 else float("nan")
            w_mask = np.isfinite(auc)
            auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
            n_valid_auc = int(np.isfinite(auc).sum())

            diag_rows.append(
                {
                    "age_bin": age_label,
                    "tau_years": float(tau),
                    "n_records": int(n_bin),
                    "n_causes": int(K),
                    "auc_macro": auc_macro,
                    "auc_weighted_by_npos": auc_weighted,
                    "n_causes_valid_auc": int(n_valid_auc),
                    "brier_macro": brier_macro,
                    "brier_weighted_by_npos": brier_weighted,
                }
            )

            diag_per_cause_parts.append(
                pd.DataFrame(
                    {
                        "age_bin": age_label,
                        "tau_years": float(tau),
                        "cause_id": np.arange(K, dtype=np.int64),
                        "n_pos": n_pos,
                        "n_neg": n_neg,
                        "auc": auc,
                        "brier": brier_per_cause,
                    }
                )
            )

    out_capture = os.path.join(run_dir, "horizon_capture.csv")
    out_wy = os.path.join(run_dir, "workload_yield.csv")
    out_diag = os.path.join(run_dir, "horizon_metrics.csv")
    out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv")

    pd.DataFrame(capture_rows).to_csv(out_capture, index=False)
    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
    pd.DataFrame(diag_rows).to_csv(out_diag, index=False)
    if diag_per_cause_parts:
        pd.concat(diag_per_cause_parts, ignore_index=True).to_csv(out_diag_pc, index=False)
    else:
        pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos",
                              "n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False)

    print(f"Wrote {out_capture}")
    print(f"Wrote {out_wy}")
    print(f"Wrote {out_diag} (diagnostic-only)")
    print(f"Wrote {out_diag_pc} (diagnostic-only)")


if __name__ == "__main__":
    main()
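
The workload–yield rows written by the script above follow the capture-rate and lift formulas in main(); the following standalone sketch (toy numbers, not part of the deleted file) reproduces them with plain NumPy:

    import math
    import numpy as np

    person_score = np.array([0.9, 0.2, 0.6, 0.1])       # max_k CIF_k(tau) per record
    events_within_tau = np.array([3, 0, 1, 1])           # event counts per record
    order = np.argsort(-person_score, kind="mergesort")

    frac = 0.5
    n_sel = min(max(int(math.ceil(frac * person_score.size)), 1), person_score.size)
    sel = order[:n_sel]                                   # top-p% records by score
    events_captured = int(events_within_tau[sel].sum())  # 3 + 1 = 4
    total_events = int(events_within_tau.sum())           # 5
    capture_rate = events_captured / total_events         # 0.8
    lift = (events_captured / n_sel) / (total_events / person_score.size)  # 2.0 / 1.25 = 1.6
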
@@ -1,311 +0,0 @@
import argparse
import os
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm  # noqa: F401
except Exception:  # pragma: no cover
    tqdm = None

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_next_token_logits,
    get_auc_delong_var,
    seed_everything,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using next-token scores"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument(
        "--max_cpu_cores",
        type=int,
        default=-1,
        help="Maximum number of CPU cores to use for parallel data construction.",
    )
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def _format_age_bin_label(lo: float, hi: float) -> str:
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"


def _compute_next_event_auc_clean_control(
    *,
    scores: np.ndarray,
    records: list,
) -> pd.DataFrame:
    """Delphi-2M next-event AUC (clean control) per cause.

    Definitions per cause k:
    - Case: next_event_cause == k
    - Control (clean): next_event_cause != k AND k not in record.lifetime_causes
    AUC is computed with DeLong variance.
    """
    n_records = int(len(records))
    if n_records == 0:
        return pd.DataFrame(
            columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
        )

    K = int(scores.shape[1])
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause)) for r in records],
        dtype=np.int64,
    )

    # Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
    # lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes.
    # Use a sparse matrix when SciPy is available to keep memory bounded for large K.
    row_parts: List[np.ndarray] = []
    col_parts: List[np.ndarray] = []
    for i, r in enumerate(records):
        causes = getattr(r, "lifetime_causes", None)
        if causes is None:
            continue
        causes = np.asarray(causes, dtype=np.int64)
        if causes.size == 0:
            continue
        # Keep only valid cause ids.
        m_valid = (causes >= 0) & (causes < K)
        if not np.any(m_valid):
            continue
        causes = causes[m_valid]
        row_parts.append(np.full((causes.size,), i, dtype=np.int32))
        col_parts.append(causes.astype(np.int32, copy=False))

    try:
        import scipy.sparse as sp  # type: ignore

        if row_parts:
            rows = np.concatenate(row_parts, axis=0)
            cols = np.concatenate(col_parts, axis=0)
            data = np.ones((rows.size,), dtype=bool)
            lifetime_matrix = sp.csc_matrix((data, (rows, cols)), shape=(n_records, K))
        else:
            lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
        lifetime_is_sparse = True
    except Exception:  # pragma: no cover
        lifetime_matrix = np.zeros((n_records, K), dtype=bool)
        for rows, cols in zip(row_parts, col_parts):
            lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(np.int64, copy=False)] = True
        lifetime_is_sparse = False

    auc = np.full((K,), np.nan, dtype=np.float64)
    var = np.full((K,), np.nan, dtype=np.float64)
    n_case = np.zeros((K,), dtype=np.int64)
    n_control = np.zeros((K,), dtype=np.int64)

    for k in range(K):
        case_mask = y_next == k
        if not np.any(case_mask):
            continue

        # Clean controls: not next-event k AND never had k in their lifetime history.
        control_mask = y_next != k
        if np.any(control_mask):
            if lifetime_is_sparse:
                had_k = np.asarray(lifetime_matrix.getcol(k).toarray().ravel(), dtype=bool)
            else:
                had_k = lifetime_matrix[:, k]
            is_clean = ~had_k
            control_mask = control_mask & is_clean

        cs = scores[case_mask, k]
        hs = scores[control_mask, k]
        n_case[k] = int(cs.size)
        n_control[k] = int(hs.size)
        if cs.size == 0 or hs.size == 0:
            continue

        a, v = get_auc_delong_var(hs, cs)
        auc[k] = float(a)
        var[k] = float(v)

    return pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_case": n_case,
            "n_control": n_control,
            "auc": auc,
            "auc_variance": var,
        }
    )


def main() -> None:
    args = parse_args()

    # Best-effort control of implicit parallelism to avoid CPU oversubscription.
    # Note: environment variables are ideally set before importing NumPy/PyTorch,
    # but setting them early in main can still affect subprocesses or lazy readers.
    if int(args.max_cpu_cores) > 0:
        n_threads = int(args.max_cpu_cores)
        torch.set_num_threads(n_threads)
        for k in (
            "OMP_NUM_THREADS",
            "MKL_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ):
            os.environ[k] = str(n_threads)
        print(f"Restricting implicit parallelism to {n_threads} threads.")
    seed_everything(args.seed)

    show_progress = (not args.no_tqdm)

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)

    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
        n_jobs=int(args.max_cpu_cores),
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)

    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    scores = predict_next_token_logits(
        model,
        head,
        loader,
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-token)",
        return_probs=True,
    )

    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause)) for r in records],
        dtype=np.int64,
    )

    # Overall (preserve existing output files/shape)
    # Strict protocol: evaluate independently per age bin (no mixing).
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1

    per_bin_metric_rows: List[dict] = []
    per_bin_auc_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)
        m = bin_idx == b
        m_scores = scores[m]
        m_records = [r for r, keep in zip(records, m) if bool(keep)]
        # Coverage metric for transparency (not Delphi-2M AUC itself).
        m_y = y_next[m]
        n_total = int(m_y.size)
        n_eligible = int((m_y >= 0).sum())
        coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "n_records_total", "value": n_total})
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible})
        per_bin_metric_rows.append(
            {"age_bin": label, "metric": "coverage", "value": coverage})

        m_auc = _compute_next_event_auc_clean_control(
            scores=m_scores,
            records=m_records,
        )
        m_auc.insert(0, "age_bin", label)
        per_bin_auc_parts.append(m_auc)

    out_metrics_bins = os.path.join(run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)

    out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
    if per_bin_auc_parts:
        pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(out_auc_bins, index=False)
    else:
        pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control",
                              "auc", "auc_variance"]).to_csv(out_auc_bins, index=False)

    print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.")
    print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_auc_bins}")


if __name__ == "__main__":
    main()
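
For reference, the clean-control construction used by _compute_next_event_auc_clean_control above can be sketched on toy data as follows (illustrative NumPy only; get_auc_delong_var comes from the deleted utils module):

    import numpy as np

    k = 2                                               # cause under evaluation
    y_next = np.array([2, 0, 1, 2, 0])                  # next-event cause per record (-1 = none)
    lifetime = [np.array([2, 5]), np.array([0]), np.array([1, 2]), np.array([2]), np.array([3])]

    case_mask = y_next == k                             # records whose next event is cause k
    had_k = np.array([bool(np.any(c == k)) for c in lifetime])
    control_mask = (y_next != k) & ~had_k               # "clean" controls: never had cause k
    # scores[case_mask, k] vs scores[control_mask, k] then feed the DeLong AUC/variance estimate.
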
utils.py
@@ -1,925 +0,0 @@
import json
import math
import os
import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split

try:
    from tqdm import tqdm as _tqdm
except Exception:  # pragma: no cover
    _tqdm = None

try:
    from joblib import Parallel, delayed  # type: ignore
except Exception:  # pragma: no cover
    Parallel = None
    delayed = None

from dataset import HealthDataset
from losses import (
    DiscreteTimeCIFNLLLoss,
    ExponentialNLLLoss,
    PiecewiseExponentialCIFNLLLoss,
)
from model import DelphiFork, SapDelphi, SimpleHead


DAYS_PER_YEAR = 365.25
N_TECH_TOKENS = 2  # pad=0, DOA=1, diseases start at 2


def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None):
    if enabled and _tqdm is not None:
        return _tqdm(iterable, desc=desc, total=total)
    return iterable


def make_inference_dataloader_kwargs(
    device: torch.device,
    num_workers: int,
) -> Dict[str, Any]:
    """DataLoader kwargs tuned for inference throughput.

    Behavior/metrics are unchanged; this only impacts speed.
    """
    use_cuda = device.type == "cuda" and torch.cuda.is_available()
    kwargs: Dict[str, Any] = {
        "pin_memory": bool(use_cuda),
    }
    if num_workers > 0:
        kwargs["persistent_workers"] = True
        # default prefetch is 2; set explicitly for clarity.
        kwargs["prefetch_factor"] = 2
    return kwargs


# -------------------------
# Config + determinism
# -------------------------

def _replace_nonstandard_json_numbers(text: str) -> str:
    # Python's json.dump writes Infinity/-Infinity/NaN for non-finite floats.
    # Replace bare tokens (not within quotes) with string placeholders.
    def repl(match: re.Match[str]) -> str:
        token = match.group(0)
        if token == "-Infinity":
            return '"__NINF__"'
        if token == "Infinity":
            return '"__INF__"'
        if token == "NaN":
            return '"__NAN__"'
        return token

    return re.sub(r'(?<![\w\."])(-Infinity|Infinity|NaN)(?![\w\."])', repl, text)


def _restore_placeholders(obj: Any) -> Any:
    if isinstance(obj, dict):
        return {k: _restore_placeholders(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_restore_placeholders(v) for v in obj]
    if obj == "__INF__":
        return float("inf")
    if obj == "__NINF__":
        return float("-inf")
    if obj == "__NAN__":
        return float("nan")
    return obj


def load_train_config(run_dir: str) -> Dict[str, Any]:
    cfg_path = os.path.join(run_dir, "train_config.json")
    with open(cfg_path, "r", encoding="utf-8") as f:
        raw = f.read()
    raw = _replace_nonstandard_json_numbers(raw)
    cfg = json.loads(raw)
    cfg = _restore_placeholders(cfg)
    return cfg
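
# Example of the round trip above (assumed config fragment): json.dump writes
#   {"bin_edges": [0.0, 1.0, Infinity]}
# The regex rewrites the bare Infinity token to the string "__INF__" before parsing, and
# _restore_placeholders maps it back, so cfg["bin_edges"][-1] == float("inf") after load_train_config.
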
def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def parse_float_list(values: Sequence[str]) -> List[float]:
    out: List[float] = []
    for v in values:
        s = str(v).strip().lower()
        if s in {"inf", "+inf", "infty", "infinity", "+infinity"}:
            out.append(float("inf"))
        elif s in {"-inf", "-infty", "-infinity"}:
            out.append(float("-inf"))
        else:
            out.append(float(v))
    return out


# -------------------------
# Dataset + split (match train.py)
# -------------------------

def build_dataset_from_config(cfg: Dict[str, Any]) -> HealthDataset:
    data_prefix = cfg["data_prefix"]
    full_cov = bool(cfg.get("full_cov", False))

    if full_cov:
        cov_list = None
    else:
        cov_list = ["bmi", "smoking", "alcohol"]

    dataset = HealthDataset(
        data_prefix=data_prefix,
        covariate_list=cov_list,
    )
    return dataset


def get_test_subset(dataset: HealthDataset, cfg: Dict[str, Any]) -> Subset:
    n_total = len(dataset)
    train_ratio = float(cfg["train_ratio"])
    val_ratio = float(cfg["val_ratio"])
    seed = int(cfg["random_seed"])

    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val

    _, _, test_subset = random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(seed),
    )
    return test_subset


# -------------------------
# Model + head + criterion (match train.py)
# -------------------------

def build_model_head_criterion(
    cfg: Dict[str, Any],
    dataset: HealthDataset,
    device: torch.device,
) -> Tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
    loss_type = cfg["loss_type"]

    if loss_type == "exponential":
        criterion = ExponentialNLLLoss(lambda_reg=float(cfg.get("lambda_reg", 0.0))).to(device)
        out_dims = [dataset.n_disease]
    elif loss_type == "discrete_time_cif":
        bin_edges = [float(x) for x in cfg["bin_edges"]]
        criterion = DiscreteTimeCIFNLLLoss(
            bin_edges=bin_edges,
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        ).to(device)
        out_dims = [dataset.n_disease + 1, len(bin_edges)]
    elif loss_type == "pwe_cif":
        # training drops +inf for PWE
        raw_edges = [float(x) for x in cfg["bin_edges"]]
        pwe_edges = [float(x) for x in raw_edges if math.isfinite(float(x))]
        if len(pwe_edges) < 2:
            raise ValueError(
                "pwe_cif requires at least 2 finite bin edges (including 0). "
                f"Got bin_edges={raw_edges}"
            )
        if float(pwe_edges[0]) != 0.0:
            raise ValueError(f"pwe_cif requires bin_edges[0]==0.0; got {pwe_edges[0]}")

        criterion = PiecewiseExponentialCIFNLLLoss(
            bin_edges=pwe_edges,
            lambda_reg=float(cfg.get("lambda_reg", 0.0)),
        ).to(device)
        n_bins = len(pwe_edges) - 1
        out_dims = [dataset.n_disease, n_bins]
    else:
        raise ValueError(f"Unsupported loss_type: {loss_type}")

    model_type = cfg["model_type"]
    if model_type == "delphi_fork":
        model = DelphiFork(
            n_disease=dataset.n_disease,
            n_tech_tokens=N_TECH_TOKENS,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
        ).to(device)
    elif model_type == "sap_delphi":
        model = SapDelphi(
            n_disease=dataset.n_disease,
            n_tech_tokens=N_TECH_TOKENS,
            n_embd=int(cfg["n_embd"]),
            n_head=int(cfg["n_head"]),
            n_layer=int(cfg["n_layer"]),
            pdrop=float(cfg.get("pdrop", 0.0)),
            age_encoder_type=str(cfg.get("age_encoder", "sinusoidal")),
            n_cont=int(dataset.n_cont),
            n_cate=int(dataset.n_cate),
            cate_dims=list(dataset.cate_dims),
            pretrained_weights_path=str(cfg.get("pretrained_emd_path", "icd10_sapbert_embeddings.npy")),
            freeze_embeddings=True,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    head = SimpleHead(
        n_embd=int(cfg["n_embd"]),
        out_dims=list(out_dims),
    ).to(device)

    return model, head, criterion


def load_checkpoint_into(
    run_dir: str,
    model: torch.nn.Module,
    head: torch.nn.Module,
    criterion: Optional[torch.nn.Module],
    device: torch.device,
) -> Dict[str, Any]:
    ckpt_path = os.path.join(run_dir, "best_model.pt")
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"], strict=True)
    head.load_state_dict(ckpt["head_state_dict"], strict=True)
    if criterion is not None and "criterion_state_dict" in ckpt:
        try:
            criterion.load_state_dict(ckpt["criterion_state_dict"], strict=False)
        except Exception:
            # Criterion state is not essential for inference.
            pass
    return ckpt


# -------------------------
# Evaluation record construction (event-driven)
# -------------------------

@dataclass(frozen=True)
class EvalRecord:
    subset_idx: int
    doa_days: float
    t0_days: float
    cutoff_pos: int  # baseline position (inclusive)
    next_event_cause: Optional[int]
    next_event_dt_years: Optional[float]
    # (U,) unique causes ever observed (clean-control filtering)
    lifetime_causes: np.ndarray
    future_causes: np.ndarray  # (E,) in [0..K-1]
    future_dt_years: np.ndarray  # (E,) strictly > 0


def _to_days(x_years: float) -> float:
    if math.isinf(float(x_years)):
        return float("inf")
    return float(x_years) * DAYS_PER_YEAR


def build_event_driven_records(
    subset: Subset,
    age_bins_years: Sequence[float],
    seed: int,
    show_progress: bool = False,
    n_jobs: int = 1,
    chunk_size: int = 256,
    prefer: str = "threads",
) -> List[EvalRecord]:
    if len(age_bins_years) < 2:
        raise ValueError("age_bins must have at least 2 boundaries")

    age_bins_days = [_to_days(b) for b in age_bins_years]
    if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
        raise ValueError("age_bins must be strictly increasing")

    def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
        if size <= 0:
            raise ValueError("chunk_size must be >= 1")
        if n == 0:
            return []
        idx = np.arange(n, dtype=np.int64)
        return [idx[i:i + size] for i in range(0, n, size)]

    def _build_records_for_index(
        subset_idx: int,
        *,
        age_bins_days_local: Sequence[float],
        rng_local: np.random.Generator,
    ) -> List[EvalRecord]:
        event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
        codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
        times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)

        doa_pos = np.flatnonzero(codes_ins == 1)
        if doa_pos.size == 0:
            raise ValueError("Expected DOA token (code=1) in event sequence")
        doa_days = float(times_ins[int(doa_pos[0])])

        is_disease = codes_ins >= N_TECH_TOKENS

        # Lifetime (ever) disease history for Clean Control filtering.
        if np.any(is_disease):
            lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(np.int64, copy=False)
            lifetime_causes = np.unique(lifetime_causes)
        else:
            lifetime_causes = np.zeros((0,), dtype=np.int64)

        disease_pos_all = np.flatnonzero(is_disease)
        disease_times_all = (
            times_ins[disease_pos_all]
            if disease_pos_all.size > 0
            else np.zeros((0,), dtype=np.float64)
        )

        eps = 1e-6
        out: List[EvalRecord] = []
        for b in range(len(age_bins_days_local) - 1):
            lo = float(age_bins_days_local[b])
            hi = float(age_bins_days_local[b + 1])

            # Inclusion rule:
            # 1) DOA <= bin_upper
            if not (doa_days <= hi):
                continue

            # 2) at least one disease event within bin, and baseline must satisfy t0 >= DOA.
            # Random Single-Point Sampling: choose exactly one valid event *index* per (patient, age_bin).
            if disease_pos_all.size == 0:
                continue

            in_bin = (
                (disease_times_all >= lo)
                & (disease_times_all < hi)
                & (disease_times_all >= doa_days)
            )
            cand_pos = disease_pos_all[in_bin]
            if cand_pos.size == 0:
                continue

            cutoff_pos = int(rng_local.choice(cand_pos))
            t0_days = float(times_ins[cutoff_pos])

            # Future disease events strictly after t0
            future_mask = (times_ins > (t0_days + eps)) & is_disease
            future_pos = np.flatnonzero(future_mask)
            if future_pos.size == 0:
                next_cause = None
                next_dt_years = None
                future_causes = np.zeros((0,), dtype=np.int64)
                future_dt_years_arr = np.zeros((0,), dtype=np.float32)
            else:
                future_times_days = times_ins[future_pos]
                future_tokens = codes_ins[future_pos]
                future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64)
                future_dt_years_arr = ((future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)

                # next-event = minimal time > t0 (tie broken by earliest position)
                next_idx = int(np.argmin(future_times_days))
                next_cause = int(future_causes[next_idx])
                next_dt_years = float(future_dt_years_arr[next_idx])

            out.append(
                EvalRecord(
                    subset_idx=int(subset_idx),
                    doa_days=float(doa_days),
                    t0_days=float(t0_days),
                    cutoff_pos=int(cutoff_pos),
                    next_event_cause=next_cause,
                    next_event_dt_years=next_dt_years,
                    lifetime_causes=lifetime_causes,
                    future_causes=future_causes,
                    future_dt_years=future_dt_years_arr,
                )
            )
        return out

    def _process_chunk(
        chunk_indices: Sequence[int],
        *,
        age_bins_days_local: Sequence[float],
        seed_local: int,
    ) -> List[EvalRecord]:
        out: List[EvalRecord] = []
        for subset_idx in chunk_indices:
            # Ensure each subject has its own deterministic RNG stream, so parallel
            # workers do not share identical seeds.
            ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
            rng_local = np.random.default_rng(ss)
            out.extend(
                _build_records_for_index(
                    int(subset_idx),
                    age_bins_days_local=age_bins_days_local,
                    rng_local=rng_local,
                )
            )
        return out

    n = int(len(subset))
    chunks = _iter_chunks(n, int(chunk_size))

    do_parallel = (
        int(n_jobs) != 1
        and Parallel is not None
        and delayed is not None
        and n > 0
    )

    if do_parallel:
        # Note: on Windows, process-based parallelism may require the underlying
        # dataset to be pickleable. `prefer="threads"` is the default for safety.
        parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
            delayed(_process_chunk)(
                chunk,
                age_bins_days_local=age_bins_days,
                seed_local=int(seed),
            )
            for chunk in chunks
        )
        records = [r for part in parts for r in part]
        return records

    # Sequential (preserve prior behavior/progress reporting)
    rng = np.random.default_rng(seed)
    records: List[EvalRecord] = []
    eps = 1e-6
    for subset_idx in _progress(
        range(len(subset)),
        enabled=show_progress,
        desc="Building eval records",
        total=len(subset),
    ):
        event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
        codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
        times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)

        doa_pos = np.flatnonzero(codes_ins == 1)
        if doa_pos.size == 0:
            raise ValueError("Expected DOA token (code=1) in event sequence")
        doa_days = float(times_ins[int(doa_pos[0])])

        is_disease = codes_ins >= N_TECH_TOKENS

        if np.any(is_disease):
            lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(np.int64, copy=False)
            lifetime_causes = np.unique(lifetime_causes)
        else:
            lifetime_causes = np.zeros((0,), dtype=np.int64)

        disease_pos_all = np.flatnonzero(is_disease)
        disease_times_all = (
            times_ins[disease_pos_all]
            if disease_pos_all.size > 0
            else np.zeros((0,), dtype=np.float64)
        )

        for b in range(len(age_bins_days) - 1):
            lo = age_bins_days[b]
            hi = age_bins_days[b + 1]
            if not (doa_days <= hi):
                continue
            if disease_pos_all.size == 0:
                continue

            in_bin = (
                (disease_times_all >= lo)
                & (disease_times_all < hi)
                & (disease_times_all >= doa_days)
            )
            cand_pos = disease_pos_all[in_bin]
            if cand_pos.size == 0:
                continue

            cutoff_pos = int(rng.choice(cand_pos))
            t0_days = float(times_ins[cutoff_pos])

            future_mask = (times_ins > (t0_days + eps)) & is_disease
            future_pos = np.flatnonzero(future_mask)
            if future_pos.size == 0:
                next_cause = None
                next_dt_years = None
                future_causes = np.zeros((0,), dtype=np.int64)
                future_dt_years_arr = np.zeros((0,), dtype=np.float32)
            else:
                future_times_days = times_ins[future_pos]
                future_tokens = codes_ins[future_pos]
                future_causes = (future_tokens - N_TECH_TOKENS).astype(np.int64)
                future_dt_years_arr = ((future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)

                next_idx = int(np.argmin(future_times_days))
                next_cause = int(future_causes[next_idx])
                next_dt_years = float(future_dt_years_arr[next_idx])

            records.append(
                EvalRecord(
                    subset_idx=int(subset_idx),
                    doa_days=float(doa_days),
                    t0_days=float(t0_days),
                    cutoff_pos=int(cutoff_pos),
                    next_event_cause=next_cause,
                    next_event_dt_years=next_dt_years,
                    lifetime_causes=lifetime_causes,
                    future_causes=future_causes,
                    future_dt_years=future_dt_years_arr,
                )
            )

    return records
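
# Note on determinism (comment added for clarity): the parallel path in build_event_driven_records
# above seeds one generator per subject via np.random.SeedSequence([seed, subset_idx]), so the
# sampled baseline for a given subject does not depend on chunking or worker order. The sequential
# fallback uses a single shared generator, so it is reproducible for a fixed seed but is not
# expected to match the parallel path draw-for-draw.
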
class EvalRecordDataset(Dataset):
|
|
||||||
def __init__(self, subset: Dataset, records: Sequence[EvalRecord]):
|
|
||||||
self.subset = subset
|
|
||||||
self.records = list(records)
|
|
||||||
self._cache: Dict[int, Tuple[torch.Tensor,
|
|
||||||
torch.Tensor, torch.Tensor, torch.Tensor, int]] = {}
|
|
||||||
self._cache_order: List[int] = []
|
|
||||||
self._cache_max = 2048
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.records)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
|
||||||
rec = self.records[idx]
|
|
||||||
cached = self._cache.get(rec.subset_idx)
|
|
||||||
if cached is None:
|
|
||||||
event_seq, time_seq, cont, cate, sex = self.subset[rec.subset_idx]
|
|
||||||
cached = (event_seq, time_seq, cont, cate, int(sex))
|
|
||||||
self._cache[rec.subset_idx] = cached
|
|
||||||
self._cache_order.append(rec.subset_idx)
|
|
||||||
if len(self._cache_order) > self._cache_max:
|
|
||||||
drop = self._cache_order.pop(0)
|
|
||||||
self._cache.pop(drop, None)
|
|
||||||
else:
|
|
||||||
event_seq, time_seq, cont, cate, sex = cached
|
|
||||||
cutoff = rec.cutoff_pos + 1
|
|
||||||
event_seq = event_seq[:cutoff]
|
|
||||||
time_seq = time_seq[:cutoff]
|
|
||||||
baseline_pos = rec.cutoff_pos # same index in truncated sequence
|
|
||||||
return event_seq, time_seq, cont, cate, sex, baseline_pos
|
|
||||||
|
|
||||||
|
|
||||||
def eval_collate_fn(batch):
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
|
|
||||||
event_seqs, time_seqs, cont_feats, cate_feats, sexes, baseline_pos = zip(
|
|
||||||
*batch)
|
|
||||||
event_batch = pad_sequence(event_seqs, batch_first=True, padding_value=0)
|
|
||||||
time_batch = pad_sequence(
|
|
||||||
time_seqs, batch_first=True, padding_value=36525.0)
|
|
||||||
cont_batch = torch.stack(cont_feats, dim=0).unsqueeze(1)
|
|
||||||
cate_batch = torch.stack(cate_feats, dim=0).unsqueeze(1)
|
|
||||||
sex_batch = torch.tensor(sexes, dtype=torch.long)
|
|
||||||
baseline_pos = torch.tensor(baseline_pos, dtype=torch.long)
|
|
||||||
return event_batch, time_batch, cont_batch, cate_batch, sex_batch, baseline_pos
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------
|
|
||||||
# Inference utilities
|
|
||||||
# -------------------------
|
|
||||||
|
|
||||||
def predict_cifs(
|
|
||||||
model: torch.nn.Module,
|
|
||||||
head: torch.nn.Module,
|
|
||||||
criterion: torch.nn.Module,
|
|
||||||
loader: DataLoader,
|
|
||||||
taus_years: Sequence[float],
|
|
||||||
device: torch.device,
|
|
||||||
show_progress: bool = False,
|
|
||||||
progress_desc: str = "Inference",
|
|
||||||
) -> np.ndarray:
|
|
||||||
model.eval()
|
|
||||||
head.eval()
|
|
||||||
|
|
||||||
taus_t = torch.tensor(list(taus_years), dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
all_out: List[np.ndarray] = []
|
|
||||||
with torch.no_grad():
|
|
||||||
for batch in _progress(
|
|
||||||
loader,
|
|
||||||
enabled=show_progress,
|
|
||||||
desc=progress_desc,
|
|
||||||
total=len(loader) if hasattr(loader, "__len__") else None,
|
|
||||||
):
|
|
||||||
event_seq, time_seq, cont, cate, sex, baseline_pos = batch
|
|
||||||
event_seq = event_seq.to(device, non_blocking=True)
|
|
||||||
time_seq = time_seq.to(device, non_blocking=True)
|
|
||||||
cont = cont.to(device, non_blocking=True)
|
|
||||||
cate = cate.to(device, non_blocking=True)
|
|
||||||
sex = sex.to(device, non_blocking=True)
|
|
||||||
baseline_pos = baseline_pos.to(device, non_blocking=True)
|
|
||||||
|
|
||||||
h = model(event_seq, time_seq, sex, cont, cate)
|
|
||||||
b_idx = torch.arange(h.size(0), device=device)
|
|
||||||
c = h[b_idx, baseline_pos]
|
|
||||||
logits = head(c)
|
|
||||||
|
|
||||||
cifs = criterion.calculate_cifs(logits, taus_t)
|
|
||||||
out = cifs.detach().cpu().numpy()
|
|
||||||
all_out.append(out)
|
|
||||||
|
|
||||||
return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_future_events(
|
|
||||||
records: Sequence[EvalRecord],
|
|
||||||
n_causes: int,
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
"""Flatten (record_idx, cause, dt_years) across all future events.
|
|
||||||
|
|
||||||
Used to build horizon labels via vectorized masking + scatter.
|
|
||||||
"""
|
|
||||||
rec_idx_parts: List[np.ndarray] = []
|
|
||||||
cause_parts: List[np.ndarray] = []
|
|
||||||
dt_parts: List[np.ndarray] = []
|
|
||||||
|
|
||||||
for i, r in enumerate(records):
|
|
||||||
if r.future_causes.size == 0:
|
|
||||||
continue
|
|
||||||
causes = r.future_causes
|
|
||||||
dts = r.future_dt_years
|
|
||||||
# Keep only valid cause ids.
|
|
||||||
m = (causes >= 0) & (causes < n_causes)
|
|
||||||
if not np.any(m):
|
|
||||||
continue
|
|
||||||
causes = causes[m].astype(np.int64, copy=False)
|
|
||||||
dts = dts[m].astype(np.float32, copy=False)
|
|
||||||
rec_idx_parts.append(np.full((causes.size,), i, dtype=np.int32))
|
|
||||||
cause_parts.append(causes)
|
|
||||||
dt_parts.append(dts)
|
|
||||||
|
|
||||||
if not rec_idx_parts:
|
|
||||||
return (
|
|
||||||
np.zeros((0,), dtype=np.int32),
|
|
||||||
np.zeros((0,), dtype=np.int64),
|
|
||||||
np.zeros((0,), dtype=np.float32),
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
np.concatenate(rec_idx_parts, axis=0),
|
|
||||||
np.concatenate(cause_parts, axis=0),
|
|
||||||
np.concatenate(dt_parts, axis=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------
|
|
||||||
# Metrics helpers
|
|
||||||
# -------------------------
|
|
||||||
|
|
||||||
def roc_auc_ovr(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Binary ROC AUC with tie-aware average ranks.

    Returns NaN if y_true has no positives or no negatives.
    """
    y_true = np.asarray(y_true).astype(np.int32)
    y_score = np.asarray(y_score).astype(np.float64)

    n_pos = int(y_true.sum())
    n = int(y_true.size)
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    order = np.argsort(y_score, kind="mergesort")
    scores_sorted = y_score[order]
    y_sorted = y_true[order]

    ranks = np.empty(n, dtype=np.float64)
    i = 0
    while i < n:
        j = i + 1
        while j < n and scores_sorted[j] == scores_sorted[i]:
            j += 1
        # Average rank for ties; ranks are 1..n.
        avg_rank = 0.5 * (i + 1 + j)
        ranks[i:j] = avg_rank
        i = j

    sum_ranks_pos = float((ranks * y_sorted).sum())
    auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)
    return float(auc)

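# Quick self-check (illustrative, not part of the original file): the rank-based AUC above
# equals the Mann-Whitney pairwise statistic, with positive/negative ties counting 0.5.
def _example_roc_auc_check() -> None:
    y = np.array([0, 0, 1, 1, 0, 1])
    s = np.array([0.1, 0.4, 0.35, 0.8, 0.35, 0.7])
    auc = roc_auc_ovr(y, s)
    pos = s[y == 1]
    neg = s[y == 0]
    # Brute-force pairwise comparison over all (positive, negative) pairs.
    pairs = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    assert abs(auc - pairs / (pos.size * neg.size)) < 1e-12
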
def topk_indices(scores: np.ndarray, k: int) -> np.ndarray:
    """Return indices of top-k scores per row (descending)."""
    if k <= 0:
        raise ValueError("k must be positive")
    n, K = scores.shape
    k = min(k, K)
    # argpartition gives arbitrary order within the top-k; sort those by score.
    part = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
    part_scores = np.take_along_axis(scores, part, axis=1)
    order = np.argsort(-part_scores, axis=1, kind="mergesort")
    return np.take_along_axis(part, order, axis=1)

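# Illustrative usage (assumption, not in the original file): select the Top-K causes per record
# from a score matrix and compute the fraction of in-horizon events captured, in the spirit of
# the Top-K Event Capture metric described in the module docstring. The helper name and the
# exact metric definition here are hypothetical.
def _example_topk_capture(scores: np.ndarray, labels: np.ndarray, k: int = 10) -> float:
    # scores, labels: both (n_records, n_causes); labels[i, c] = 1 if cause c occurs in the horizon.
    top = topk_indices(scores, k)                           # (n_records, k)
    captured = np.take_along_axis(labels, top, axis=1).sum()
    total = labels.sum()
    return float(captured / total) if total > 0 else float("nan")
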
# -------------------------
# Statistical evaluation (DeLong)
# -------------------------

def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Compute midranks of a 1D array (1-based ranks, tie-aware)."""
    x = np.asarray(x, dtype=np.float64)
    if x.ndim != 1:
        raise ValueError("compute_midrank expects a 1D array")

    order = np.argsort(x, kind="mergesort")
    x_sorted = x[order]
    n = int(x_sorted.size)

    midranks = np.empty((n,), dtype=np.float64)
    i = 0
    while i < n:
        j = i
        while j < n and x_sorted[j] == x_sorted[i]:
            j += 1
        # Ranks are 1..n; average over ties.
        mid = 0.5 * ((i + 1) + j)
        midranks[i:j] = mid
        i = j

    out = np.empty((n,), dtype=np.float64)
    out[order] = midranks
    return out

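# Tiny worked example (illustrative, not part of the original file): midranks are returned in
# the original order, and tied values share the average of their 1-based ranks.
def _example_midrank() -> None:
    mr = compute_midrank(np.array([1.0, 3.0, 3.0, 2.0]))
    assert np.allclose(mr, [1.0, 3.5, 3.5, 2.0])
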
def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for AUC covariance.

    Args:
        predictions_sorted_transposed: shape (n_classifiers, n_examples), where the first
            label_1_count examples are positives.
        label_1_count: number of positive examples.
    Returns:
        (aucs, delong_cov)
    """
    preds = np.asarray(predictions_sorted_transposed, dtype=np.float64)
    if preds.ndim != 2:
        raise ValueError("predictions_sorted_transposed must be 2D")

    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        raise ValueError("DeLong requires at least 1 positive and 1 negative")

    k = int(preds.shape[0])
    tx = np.empty((k, m), dtype=np.float64)
    ty = np.empty((k, n), dtype=np.float64)
    tz = np.empty((k, m + n), dtype=np.float64)

    for r in range(k):
        tx[r] = compute_midrank(preds[r, :m])
        ty[r] = compute_midrank(preds[r, m:])
        tz[r] = compute_midrank(preds[r, :])

    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)

    v01 = (tz[:, :m] - tx) / float(n)
    v10 = 1.0 - (tz[:, m:] - ty) / float(m)

    # np.cov treats rows as variables when rowvar=True (the default).
    sx = np.cov(v01, rowvar=True, bias=False)
    sy = np.cov(v10, rowvar=True, bias=False)
    delong_cov = sx / float(m) + sy / float(n)
    return aucs, delong_cov

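# Illustrative sketch (not in the original file): comparing two correlated AUCs on the same
# test set via the covariance returned by fastDeLong, i.e. a DeLong z-test on paired scores.
# Assumes scipy is available in the environment; the helper name is hypothetical.
def _example_delong_test(y_true: np.ndarray, scores_a: np.ndarray, scores_b: np.ndarray) -> float:
    from scipy import stats

    order, m = compute_ground_truth_statistics(y_true)
    preds = np.vstack([scores_a[order], scores_b[order]])   # (2, n_examples), positives first
    aucs, cov = fastDeLong(preds, m)
    diff = aucs[0] - aucs[1]
    var = cov[0, 0] + cov[1, 1] - 2.0 * cov[0, 1]
    z = diff / np.sqrt(var)
    return float(2.0 * stats.norm.sf(abs(z)))               # two-sided p-value
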
def compute_ground_truth_statistics(ground_truth: np.ndarray) -> Tuple[np.ndarray, int]:
    """Return the ordering that places positives first, plus label_1_count."""
    y = np.asarray(ground_truth, dtype=np.int32)
    if y.ndim != 1:
        raise ValueError("ground_truth must be 1D")
    label_1_count = int(y.sum())
    order = np.argsort(-y, kind="mergesort")
    return order, label_1_count

def get_auc_delong_var(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float]:
    """Compute AUC and its DeLong variance.

    Args:
        healthy_scores: scores for controls (label=0)
        diseased_scores: scores for cases (label=1)
    Returns:
        (auc, auc_variance)
    """
    h = np.asarray(healthy_scores, dtype=np.float64).reshape(-1)
    d = np.asarray(diseased_scores, dtype=np.float64).reshape(-1)
    n0 = int(h.size)
    n1 = int(d.size)
    if n0 == 0 or n1 == 0:
        return float("nan"), float("nan")

    # Arrange positives first, as required by fastDeLong.
    scores = np.concatenate([d, h], axis=0)
    gt = np.concatenate([
        np.ones((n1,), dtype=np.int32),
        np.zeros((n0,), dtype=np.int32),
    ])
    order, label_1_count = compute_ground_truth_statistics(gt)
    preds_sorted = scores[order][None, :]
    aucs, cov = fastDeLong(preds_sorted, label_1_count)
    auc = float(aucs[0])
    cov = np.asarray(cov)
    var = float(cov[0, 0]) if cov.ndim == 2 else float(cov)
    return auc, var

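# Illustrative follow-up (not in the original file): a normal-approximation 95% confidence
# interval built from the AUC and DeLong variance returned above. The helper name is hypothetical.
def _example_auc_ci(healthy_scores: np.ndarray, diseased_scores: np.ndarray) -> Tuple[float, float, float]:
    auc, var = get_auc_delong_var(healthy_scores, diseased_scores)
    half = 1.96 * float(np.sqrt(var))
    return auc, max(0.0, auc - half), min(1.0, auc + half)
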
# -------------------------
# Next-token inference helper
# -------------------------

def predict_next_token_logits(
    model: torch.nn.Module,
    head: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
    show_progress: bool = False,
    progress_desc: str = "Inference (next-token)",
    return_probs: bool = True,
) -> np.ndarray:
    """Predict per-cause next-token scores at baseline positions.

    Returns:
        np.ndarray of shape (N, K) where K is the number of diseases (causes).

    Notes:
        - For loss types with time/bin dimensions (e.g., discrete-time CIF), this uses the
          *first* time/bin (index 0) and drops the complement channel when present.
        - If return_probs=True, applies softmax over causes for probability-like scores.
    """
    model.eval()
    head.eval()

    all_out: List[np.ndarray] = []
    with torch.no_grad():
        for batch in _progress(
            loader,
            enabled=show_progress,
            desc=progress_desc,
            total=len(loader) if hasattr(loader, "__len__") else None,
        ):
            event_seq, time_seq, cont, cate, sex, baseline_pos = batch
            event_seq = event_seq.to(device, non_blocking=True)
            time_seq = time_seq.to(device, non_blocking=True)
            cont = cont.to(device, non_blocking=True)
            cate = cate.to(device, non_blocking=True)
            sex = sex.to(device, non_blocking=True)
            baseline_pos = baseline_pos.to(device, non_blocking=True)

            h = model(event_seq, time_seq, sex, cont, cate)
            b_idx = torch.arange(h.size(0), device=device)
            c = h[b_idx, baseline_pos]
            logits = head(c)

            # logits can be (B, K), (B, K, T) or (B, K+1, T).
            if logits.ndim == 2:
                cause_logits = logits
            elif logits.ndim == 3:
                # Use the first time/bin.
                cause_logits = logits[..., 0]
            else:
                raise ValueError(
                    f"Unsupported logits shape for next-token inference: {tuple(logits.shape)}"
                )

            # If a complement/survival channel exists (discrete-time CIF), drop it so only
            # the first n_disease cause channels remain.
            if hasattr(model, "n_disease"):
                n_disease = int(getattr(model, "n_disease"))
                if cause_logits.size(1) > n_disease:
                    cause_logits = cause_logits[:, :n_disease]

            if return_probs:
                scores = torch.softmax(cause_logits, dim=1)
            else:
                scores = cause_logits

            all_out.append(scores.detach().cpu().numpy())

    return np.concatenate(all_out, axis=0) if all_out else np.zeros((0,))

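# Minimal usage sketch (assumption, not from the original file): next-token disease scores at
# the baseline position, wired to the same kind of eval DataLoader as the CIF path above.
def _example_next_token_usage(model, head, loader, device) -> None:
    probs = predict_next_token_logits(
        model, head, loader, device,
        show_progress=True,
        return_probs=True,
    )
    # probs has shape (n_records, n_causes); rows sum to ~1 when return_probs=True, so
    # topk_indices(probs, k) yields the k most likely next diseases per record.
    print(probs.shape)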