# DeepHealth/evaluate_next_event.py
import argparse
import os
from typing import List

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

try:
    from tqdm import tqdm  # noqa: F401
except Exception:  # pragma: no cover
    tqdm = None

from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    load_checkpoint_into,
    load_train_config,
    make_inference_dataloader_kwargs,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
    DAYS_PER_YEAR,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using short-window CIF"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument(
        "--tau_short",
        type=float,
        required=True,
        help="Short CIF horizon in years",
    )
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum number of positives required to report a per-cause AUC",
    )
    p.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )
    return p.parse_args()


def _format_age_bin_label(lo: float, hi: float) -> str:
    """Render a half-open age bin, e.g. (40.0, 45.0) -> "[40.0, 45.0)"."""
    if np.isinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"


def _compute_next_event_metrics(
    *,
    scores: np.ndarray,
    y_next: np.ndarray,
    tau_short: float,
    min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
    """Compute next-event metrics on a given subset.

    Definitions are unchanged from the original script. Records with
    ``y_next < 0`` have no observed next event; they are excluded from the
    ranking metrics and only contribute to coverage.
    """
    n_records_total = int(y_next.size)
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = float(n_eligible / n_records_total) if n_records_total > 0 else 0.0

    metrics_rows: List[dict] = []
    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
    metrics_rows.append({"metric": "n_next_event_eligible", "value": n_eligible})
    metrics_rows.append({"metric": "coverage", "value": coverage})
    metrics_rows.append({"metric": "tau_short_years", "value": float(tau_short)})

    K = int(scores.shape[1])

    # No eligible records (this subsumes the empty-subset case): keep the
    # count/coverage rows but leave the accuracy-like metrics as NaN.
    # Positives and negatives are counted among eligible records, so both
    # are zero here.
    if n_eligible == 0:
        per_cause_df = pd.DataFrame(
            {
                "cause_id": np.arange(K, dtype=np.int64),
                "n_pos": np.zeros((K,), dtype=np.int64),
                "n_neg": np.zeros((K,), dtype=np.int64),
                "auc": np.full((K,), np.nan, dtype=np.float64),
                "included": np.zeros((K,), dtype=bool),
            }
        )
        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
        metrics_rows.append({"metric": "mrr", "value": float("nan")})
        for k in [1, 3, 5, 10, 20]:
            metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float("nan")})
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
        return metrics_rows, per_cause_df

    scores_e = scores[eligible]
    y_e = y_next[eligible]

    # Top-1 accuracy: does the highest-scoring cause match the observed one?
    pred = scores_e.argmax(axis=1)
    acc = float((pred == y_e).mean())
    metrics_rows.append({"metric": "top1_accuracy", "value": acc})

    # MRR: mean reciprocal rank of the true cause under a stable descending
    # sort of the scores (rank 1 = best). The 1-based rank of y_e[i] is the
    # position at which it appears in row i's ordering.
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = 1 + np.argmax(order == y_e[:, None], axis=1).astype(np.int32)
    mrr = float((1.0 / ranks).mean())
    metrics_rows.append({"metric": "mrr", "value": mrr})

    # HitRate@K: fraction of records whose true cause appears in the top K.
    for k in [1, 3, 5, 10, 20]:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append({"metric": f"hitrate_at_{k}", "value": float(hit.mean())})

    # Macro OvR AUC, restricted to causes with enough positives and at least
    # one negative among the eligible records.
    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
    auc = np.full((K,), np.nan, dtype=np.float64)
    included = (n_pos >= int(min_pos)) & (n_neg > 0)
    for k in np.flatnonzero(included):
        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
    per_cause_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": n_pos,
            "n_neg": n_neg,
            "auc": auc,
            "included": included,
        }
    )
    aucs = auc[np.isfinite(auc)]
    if aucs.size > 0:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
    else:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
    return metrics_rows, per_cause_df
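

# Worked example of the ranking metrics above (illustrative numbers, not part
# of the pipeline): with one eligible record scored scores = [[0.1, 0.7, 0.2]]
# and true next cause y_next = [2], the argmax prediction is cause 1, so
# top1_accuracy = 0.0; the descending order is [1, 2, 0], so the true cause
# ranks 2nd and mrr = 1/2 = 0.5; hitrate_at_1 = 0.0 while hitrate_at_3 = 1.0.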


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    show_progress = not args.no_tqdm

    # Rebuild the dataset and the event-driven evaluation records from the
    # training config stored in the run directory.
    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)
    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        dataset=dataset,
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
    )

    # Restore the trained model/head/criterion and score every record.
    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)
    rec_ds = EvalRecordDataset(dataset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )
    tau = float(args.tau_short)
    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        [tau],
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-event)",
    )
    # predict_cifs returns shape (N, K, n_taus) for a list of taus; with a
    # single tau, drop the trailing axis to get (N, K).
    if scores.ndim == 3:
        scores = scores[:, :, 0]

    # Encode "no observed next event" as -1; such records are excluded from
    # the ranking metrics and only contribute to coverage.
    y_next = np.array(
        [
            (-1 if r.next_event_cause is None else int(r.next_event_cause))
            for r in records
        ],
        dtype=np.int64,
    )

    # Overall metrics (preserves the existing output files and their shape).
    metrics_rows, per_cause_df = _compute_next_event_metrics(
        scores=scores,
        y_next=y_next,
        tau_short=tau,
        min_pos=int(args.min_pos),
    )
    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
    per_cause_df.to_csv(out_pc, index=False)

    # Per-age-bin breakdown (new outputs).
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR
    # Bin assignment from t0 (constructed within the bin). side="right" makes
    # the boundaries behave as half-open intervals [b_i, b_{i+1}), matching
    # the bin labels; side="left" would send a t0 falling exactly on a
    # boundary to the lower bin.
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1
    per_bin_metric_rows: List[dict] = []
    per_bin_cause_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)
        m = bin_idx == b
        m_rows, m_pc = _compute_next_event_metrics(
            scores=scores[m],
            y_next=y_next[m],
            tau_short=tau,
            min_pos=int(args.min_pos),
        )
        for row in m_rows:
            per_bin_metric_rows.append({"age_bin": label, **row})
        m_pc = m_pc.copy()
        m_pc.insert(0, "age_bin", label)
        per_bin_cause_parts.append(m_pc)

    out_metrics_bins = os.path.join(run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
    out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
    if per_bin_cause_parts:
        pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
            out_pc_bins, index=False
        )
    else:
        pd.DataFrame(
            columns=["age_bin", "cause_id", "n_pos", "n_neg", "auc", "included"]
        ).to_csv(out_pc_bins, index=False)
    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_pc_bins}")


if __name__ == "__main__":
    main()
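

# Example invocation (the run directory and values are illustrative; --run_dir
# must point at a training run containing the config and checkpoint that
# load_train_config / load_checkpoint_into expect):
#
#   python evaluate_next_event.py \
#       --run_dir runs/exp01 \
#       --tau_short 0.5 \
#       --age_bins 40 50 60 70 inf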