# DeepHealth/evaluate_next_event.py
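"""Evaluate next-event prediction for a trained run using short-horizon CIF scores.

For each event-driven evaluation record in the test split, the per-cause CIF at
the --tau_short horizon is scored and ranked against the observed next event.
Reported metrics: top-1 accuracy, mean reciprocal rank (MRR), HitRate@K, and a
macro one-vs-rest AUC over causes with at least --min_pos positives. Results
are written to next_event_metrics.csv and next_event_per_cause.csv in --run_dir.
"""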
import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from utils import (
    EvalRecordDataset,
    build_dataset_from_config,
    build_event_driven_records,
    build_model_head_criterion,
    eval_collate_fn,
    get_test_subset,
    make_inference_dataloader_kwargs,
    load_checkpoint_into,
    load_train_config,
    parse_float_list,
    predict_cifs,
    roc_auc_ovr,
    seed_everything,
    topk_indices,
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Evaluate next-event prediction using short-window CIF"
    )
    p.add_argument("--run_dir", type=str, required=True)
    p.add_argument("--tau_short", type=float, required=True, help="years")
    p.add_argument(
        "--age_bins",
        type=str,
        nargs="+",
        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
    )
    p.add_argument(
        "--device",
        type=str,
        default=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    p.add_argument("--batch_size", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    return p.parse_args()


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        dataset=dataset,
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(dataset, records)
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    tau = float(args.tau_short)
    scores = predict_cifs(model, head, criterion, loader, [tau], device=device)
    # With a list of taus, scores has shape (N, K, n_taus); keep the single tau.
    if scores.ndim == 3:
        scores = scores[:, :, 0]

    n_records_total = len(records)
    # -1 marks records without an observed next event; they are excluded below.
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
         for r in records],
        dtype=np.int64,
    )
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = (
        float(n_eligible / n_records_total) if n_records_total > 0 else 0.0
    )
    metrics_rows: List[dict] = []
    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
    metrics_rows.append({"metric": "n_next_event_eligible", "value": n_eligible})
    metrics_rows.append({"metric": "coverage", "value": coverage})
    metrics_rows.append({"metric": "tau_short_years", "value": tau})

    if n_eligible == 0:
        out_path = os.path.join(run_dir, "next_event_metrics.csv")
        pd.DataFrame(metrics_rows).to_csv(out_path, index=False)
        print(f"No eligible records; wrote {out_path}")
        return

    scores_e = scores[eligible]
    y_e = y_next[eligible]

    # Top-1 accuracy of the highest-scoring cause.
    pred = scores_e.argmax(axis=1)
    acc = float((pred == y_e).mean())
    metrics_rows.append({"metric": "top1_accuracy", "value": acc})

    # MRR: mean reciprocal rank of the true cause in descending score order.
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = np.empty(y_e.shape[0], dtype=np.int32)
    for i in range(y_e.shape[0]):
        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
    mrr = float((1.0 / ranks).mean())
    metrics_rows.append({"metric": "mrr", "value": mrr})

    # HitRate@K: fraction of records whose true cause is in the top-K scores.
    for k in [1, 3, 5, 10, 20]:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append(
            {"metric": f"hitrate_at_{k}", "value": float(hit.mean())}
        )
    # Macro one-vs-rest AUC per cause, limited to causes with enough positives.
    K = scores.shape[1]
    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
    auc = np.full((K,), np.nan, dtype=np.float64)
    min_pos = int(args.min_pos)
    included = (n_pos >= min_pos) & (n_neg > 0)
    for k in np.flatnonzero(included):
        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
    per_cause_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": n_pos,
            "n_neg": n_neg,
            "auc": auc,
            "included": included,
        }
    )
    aucs = auc[np.isfinite(auc)]
    if aucs.size > 0:
        metrics_rows.append(
            {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
    else:
        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
    # Per-cause AUC breakdown.
    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
    per_cause_df.to_csv(out_pc, index=False)
    print(f"Wrote {out_metrics}")
    print(f"Wrote {out_pc}")


if __name__ == "__main__":
    main()
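
# Example invocation (run directory and horizon are illustrative, not defaults):
#   python evaluate_next_event.py --run_dir runs/exp01 --tau_short 0.5
# Writes next_event_metrics.csv and, when eligible records exist,
# next_event_per_cause.csv into the given --run_dir.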