# DeepHealth/evaluate_next_event.py
import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
    from tqdm import tqdm  # noqa: F401
except Exception:  # pragma: no cover
    tqdm = None
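# Note: tqdm is imported here for availability only (hence the noqa: F401);
# progress display is presumably driven through the utils helpers, which are
# expected to degrade to plain iteration when tqdm is missing.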
from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using short-window CIF"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument("--tau_short", type=float, required=True, help="years")
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()
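

# Example invocation (hypothetical run directory; a minimal sketch of the CLI
# defined above):
#
#   python evaluate_next_event.py --run_dir runs/exp0 --tau_short 0.5 \
#       --age_bins 40 50 60 70 inf --batch_size 512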


def _format_age_bin_label(lo: float, hi: float) -> str:
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"


def _compute_next_event_metrics(
    *,
    scores: np.ndarray,
    y_next: np.ndarray,
    tau_short: float,
    min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
    """Compute next-event metrics on a given subset.

    Definitions are unchanged from the original script.
    """
    n_records_total = int(y_next.size)
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = float(n_eligible / n_records_total) if n_records_total > 0 else 0.0
    metrics_rows: List[dict] = []
    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
    metrics_rows.append({"metric": "n_next_event_eligible", "value": n_eligible})
    metrics_rows.append({"metric": "coverage", "value": coverage})
    metrics_rows.append({"metric": "tau_short_years", "value": float(tau_short)})
    K = int(scores.shape[1])
    if n_records_total == 0:
        per_cause_df = pd.DataFrame(
            {
                "cause_id": np.arange(K, dtype=np.int64),
                "n_pos": np.zeros((K,), dtype=np.int64),
                "n_neg": np.zeros((K,), dtype=np.int64),
                "auc": np.full((K,), np.nan, dtype=np.float64),
                "included": np.zeros((K,), dtype=bool),
            }
        )
        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
        metrics_rows.append({"metric": "mrr", "value": float("nan")})
        for k in [1, 3, 5, 10, 20]:
            metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float("nan")})
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
        return metrics_rows, per_cause_df

    # If no records are eligible, keep coverage but leave accuracy-like
    # metrics as NaN.
    if n_eligible == 0:
        per_cause_df = pd.DataFrame(
            {
                "cause_id": np.arange(K, dtype=np.int64),
                "n_pos": np.zeros((K,), dtype=np.int64),
                "n_neg": np.full((K,), n_records_total, dtype=np.int64),
                "auc": np.full((K,), np.nan, dtype=np.float64),
                "included": np.zeros((K,), dtype=bool),
            }
        )
        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
        metrics_rows.append({"metric": "mrr", "value": float("nan")})
        for k in [1, 3, 5, 10, 20]:
            metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float("nan")})
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
        return metrics_rows, per_cause_df
    scores_e = scores[eligible]
    y_e = y_next[eligible]

    # Top-1 accuracy: the highest-scoring cause matches the observed next event.
    pred = scores_e.argmax(axis=1)
    acc = float((pred == y_e).mean())
    metrics_rows.append({"metric": "top1_accuracy", "value": acc})

    # MRR: mean reciprocal rank of the true cause under a stable descending sort.
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = np.empty(y_e.shape[0], dtype=np.int32)
    for i in range(y_e.shape[0]):
        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
    mrr = float((1.0 / ranks).mean())
    metrics_rows.append({"metric": "mrr", "value": mrr})

    # HitRate@K: fraction of records whose true cause appears in the top-K scores.
    for k in [1, 3, 5, 10, 20]:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float(hit.mean())})

    # Macro OvR AUC per cause (optional): a cause is included only if it has
    # at least `min_pos` positives and at least one negative.
    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
    auc = np.full((K,), np.nan, dtype=np.float64)
    included = (n_pos >= int(min_pos)) & (n_neg > 0)
    for k in np.flatnonzero(included):
        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
    per_cause_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": n_pos,
            "n_neg": n_neg,
            "auc": auc,
            "included": included,
        }
    )
    aucs = auc[np.isfinite(auc)]
    if aucs.size > 0:
        metrics_rows.append(
            {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))}
        )
    else:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
    return metrics_rows, per_cause_df
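

# A minimal sketch of the metric helper on synthetic inputs (illustrative
# shapes and values only, not real model output):
#
#     scores = np.array([[0.1, 0.7, 0.2],
#                        [0.6, 0.3, 0.1]])
#     y_next = np.array([1, -1])  # second record has no eligible next event
#     rows, per_cause = _compute_next_event_metrics(
#         scores=scores, y_next=y_next, tau_short=1.0, min_pos=1,
#     )
#     # -> coverage = 0.5; top1_accuracy = mrr = 1.0 on the one eligible record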


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    show_progress = not args.no_tqdm
    run_dir = args.run_dir

    # Rebuild the dataset and its test split exactly as configured at train time.
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)
    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        dataset=dataset,
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
    )

    # Restore the trained model, head, and criterion from the run checkpoint.
    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(dataset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )
    # Score every record: CIF for each cause at a single short horizon tau
    # (in years).
    tau = float(args.tau_short)
    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        [tau],
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-event)",
    )
    # predict_cifs returns (N, K, n_taus); with a single tau the trailing
    # axis has length 1, so squeeze it down to (N, K).
    if scores.ndim == 3:
        scores = scores[:, :, 0]

    # Next-event target per record; -1 marks records with no eligible next event.
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
         for r in records],
        dtype=np.int64,
    )
    # Overall metrics (preserves the existing output files and their shape).
    metrics_rows, per_cause_df = _compute_next_event_metrics(
        scores=scores,
        y_next=y_next,
        tau_short=tau,
        min_pos=int(args.min_pos),
    )
    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
    per_cause_df.to_csv(out_pc, index=False)
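
    # The metrics CSV is long-form, one row per metric, e.g. (illustrative
    # values):
    #
    #   metric,value
    #   n_records_total,12345
    #   n_next_event_eligible,10740
    #   coverage,0.87
    #   top1_accuracy,0.41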

    # Metrics stratified by age bin (new outputs).
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    # Bin assignment from t0 (each record is constructed within its bin):
    # record i falls in bin b iff age_bins_days[b] <= t0_days[i] <
    # age_bins_days[b + 1]. side="right" keeps a t0 exactly on a boundary in
    # the bin that starts there, matching the half-open [lo, hi) convention
    # (side="left" would push boundary values into the lower bin).
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1
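
    # Example (illustrative): with boundaries [40, 45, inf] in years, a record
    # with t0 at 43.2 years lands in bin 0 ([40.0, 45.0)), one at exactly 45.0
    # years in bin 1 ([45.0, inf)), and one at 38.0 years gets bin -1 and is
    # skipped by every per-bin mask below.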
    per_bin_metric_rows: List[dict] = []
    per_bin_cause_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)
        m = bin_idx == b
        m_rows, m_pc = _compute_next_event_metrics(
            scores=scores[m],
            y_next=y_next[m],
            tau_short=tau,
            min_pos=int(args.min_pos),
        )
        for row in m_rows:
            per_bin_metric_rows.append({"age_bin": label, **row})
        m_pc = m_pc.copy()
        m_pc.insert(0, "age_bin", label)
        per_bin_cause_parts.append(m_pc)
    out_metrics_bins = os.path.join(run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
    out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
    if per_bin_cause_parts:
        pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
            out_pc_bins, index=False
        )
    else:
        # No bins at all (e.g. a single boundary): still write an empty,
        # correctly shaped per-cause file.
        pd.DataFrame(
            columns=["age_bin", "cause_id", "n_pos", "n_neg", "auc", "included"]
        ).to_csv(out_pc_bins, index=False)

    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_pc_bins}")


if __name__ == "__main__":
    main()