Files
DeepHealth/evaluate_next_event.py

327 lines
10 KiB
Python
Raw Normal View History

import argparse
import os
from typing import List
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
from tqdm import tqdm # noqa: F401
except Exception: # pragma: no cover
tqdm = None
from utils import (
EvalRecordDataset,
build_dataset_from_config,
build_event_driven_records,
build_model_head_criterion,
eval_collate_fn,
get_test_subset,
make_inference_dataloader_kwargs,
load_checkpoint_into,
load_train_config,
parse_float_list,
predict_cifs,
roc_auc_ovr,
seed_everything,
topk_indices,
DAYS_PER_YEAR,
)
def parse_args() -> argparse.Namespace:
    """Define the command-line interface of this script and parse ``sys.argv``.

    ``--run_dir`` and ``--tau_short`` (in years) are mandatory; every other
    option has a default.  ``--age_bins`` values are kept as strings here so
    that ``inf`` can be typed literally on the command line.
    """
    # Resolve the device default once, up front, instead of inline.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(
        description="Evaluate next-event prediction using short-window CIF",
    )

    # Required inputs.
    parser.add_argument("--run_dir", required=True, type=str)
    parser.add_argument("--tau_short", required=True, type=float, help="years")

    # Age-bin edges; consecutive values delimit one bin each.
    parser.add_argument(
        "--age_bins",
        nargs="+",
        type=str,
        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
    )

    # Inference settings.
    parser.add_argument("--device", type=str, default=default_device)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--seed", type=int, default=0)

    # Diagnostics / display.
    parser.add_argument(
        "--min_pos",
        type=int,
        default=20,
        help="Minimum positives for per-cause AUC",
    )
    parser.add_argument(
        "--no_tqdm",
        action="store_true",
        help="Disable tqdm progress bars",
    )

    return parser.parse_args()
def _format_age_bin_label(lo: float, hi: float) -> str:
    """Render an age bin as the half-open interval string ``"[lo, hi)"``.

    An infinite upper edge is always printed as the literal ``inf``.
    """
    upper_text = "inf" if np.isinf(hi) else f"{hi}"
    return f"[{lo}, {upper_text})"
def _compute_next_event_metrics(
    *,
    scores: np.ndarray,
    y_next: np.ndarray,
    tau_short: float,
    min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
    """Compute the next-event *primary* metrics (Hit@K, MRR) on one subset.

    Follows the "Next-event" section of the project's evaluation-protocol
    document (评估方案.md):

    - ``score_k = CIF_k(tau_short)``;
    - Hit@K / MRR are computed only on records with an observed next event.

    Args:
        scores: 2-D array indexed ``[record, cause]``.
        y_next: 1-D integer array with the next-event cause per record; a
            negative value marks a record without an observed next event.
        tau_short: Horizon in years; only echoed into the metrics rows.
        min_pos: Accepted for signature symmetry with the AUC helper; it is
            not used here because the per-cause AUC is computed elsewhere.

    Returns:
        ``(metrics_rows, diag_df)``.  ``metrics_rows`` is a list of
        ``{"metric", "value"}`` dicts.  ``diag_df`` is a placeholder
        per-cause table (zero counts, NaN AUC, nothing included).

    Raises:
        IndexError: if an eligible label is not a valid column of ``scores``.

    Note:
        ``hitrate_at_1`` is emitted twice (once from ``argmax``, once from the
        Hit@K loop).  This mirrors the historical output and is kept so the
        CSV written by the caller keeps its shape.
    """
    hit_ks = (1, 3, 5, 10, 20)

    n_records_total = int(y_next.size)
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = (
        float(n_eligible / n_records_total) if n_records_total > 0 else 0.0
    )

    metrics_rows: list[dict] = [
        {"metric": "n_records_total", "value": n_records_total},
        {"metric": "n_next_event_eligible", "value": n_eligible},
        {"metric": "coverage", "value": coverage},
        {"metric": "tau_short_years", "value": float(tau_short)},
    ]

    K = int(scores.shape[1])

    # Placeholder diagnostic table; this is NOT a primary metric
    # (no IPCW / censoring adjustment) and is filled in by a separate helper.
    diag_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": np.zeros((K,), dtype=np.int64),
            "n_neg": np.zeros((K,), dtype=np.int64),
            "auc": np.full((K,), np.nan, dtype=np.float64),
            "included": np.zeros((K,), dtype=bool),
        }
    )

    # No eligible record (this also covers an empty subset): keep the
    # coverage rows and report every accuracy-like metric as NaN.
    if n_eligible == 0:
        nan = float("nan")
        metrics_rows.append({"metric": "hitrate_at_1", "value": nan})
        metrics_rows.append({"metric": "mrr", "value": nan})
        metrics_rows.extend(
            {"metric": f"hitrate_at_{k}", "value": nan} for k in hit_ks
        )
        return metrics_rows, diag_df

    scores_e = scores[eligible]
    y_e = y_next[eligible]

    top1 = scores_e.argmax(axis=1)

    # A label outside [0, K) has no rank; fail loudly instead of letting the
    # vectorised rank lookup below silently report rank 1.
    if np.any(y_e >= K):
        bad = int(y_e.max())
        raise IndexError(
            f"next-event cause id {bad} is out of range for K={K} score columns"
        )

    metrics_rows.append(
        {"metric": "hitrate_at_1", "value": float((top1 == y_e).mean())}
    )

    # MRR: a stable descending sort breaks ties by ascending cause id; the
    # 1-based rank is the position of the true cause in that ordering.  Each
    # row of `order` is a permutation, so exactly one position matches.
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = (order == y_e[:, None]).argmax(axis=1) + 1
    metrics_rows.append({"metric": "mrr", "value": float((1.0 / ranks).mean())})

    # HitRate@K
    for k in hit_ks:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append(
            {"metric": f"hitrate_at_{k}", "value": float(hit.mean())}
        )

    return metrics_rows, diag_df
def _compute_within_window_auc(
    *,
    scores: np.ndarray,
    records: list,
    tau_short: float,
    min_pos: int,
) -> pd.DataFrame:
    """Build the diagnostic-only per-cause AUC table.

    Label definition (event-driven, approximate; no IPCW): ``y[i, k] = 1``
    when record ``i`` has at least one future event of cause ``k`` whose
    ``future_dt_years`` is ``<= tau_short``.
    NOTE(review): only the upper bound is tested here; the intended window
    ``(t0, t0 + tau_short]`` presumably relies on future offsets being
    positive — confirm against the record builder.

    A cause gets an AUC only when it has at least ``min_pos`` positives and
    at least one negative; all other causes keep ``auc = NaN`` and
    ``included = False``.  With no records, an empty frame carrying just the
    column names is returned.
    """
    column_names = ["cause_id", "n_pos", "n_neg", "auc", "included"]

    n_records = int(len(records))
    if n_records == 0:
        return pd.DataFrame(columns=column_names)

    K = int(scores.shape[1])
    horizon = float(tau_short)

    # One row per record, one column per cause; filled from future events.
    labels = np.zeros((n_records, K), dtype=np.int8)
    for row, rec in enumerate(records):
        if rec.future_causes.size == 0:
            continue
        in_window = rec.future_dt_years <= horizon
        if np.any(in_window):
            labels[row, rec.future_causes[in_window]] = 1

    n_pos = labels.sum(axis=0).astype(np.int64)
    n_neg = (int(n_records) - n_pos).astype(np.int64)

    # The same mask selects the causes to score and is reported as `included`.
    included = (n_pos >= int(min_pos)) & (n_neg > 0)

    auc = np.full((K,), np.nan, dtype=np.float64)
    for cause in np.flatnonzero(included):
        auc[cause] = roc_auc_ovr(
            labels[:, cause].astype(np.int32),
            scores[:, cause].astype(np.float64),
        )

    table = dict(
        zip(
            column_names,
            (np.arange(K, dtype=np.int64), n_pos, n_neg, auc, included),
        )
    )
    return pd.DataFrame(table)
def main() -> None:
    """Evaluate next-event prediction on the test subset, one age bin at a time.

    Steps, in order:
      1. Parse CLI arguments, call ``seed_everything``, load the training
         config from ``--run_dir`` and rebuild the dataset / test subset.
      2. Build event-driven evaluation records for the requested age bins.
      3. Rebuild model, head and criterion, load the checkpoint, and compute
         scores with ``predict_cifs`` at the single horizon ``--tau_short``.
      4. For each age bin ``[b_i, b_{i+1})`` compute Hit@K / MRR and the
         diagnostic within-window AUC table on that bin's records only.
      5. Write ``next_event_metrics_by_age_bin.csv`` and
         ``next_event_auc_by_age_bin.csv`` into the run directory.
    """
    args = parse_args()
    seed_everything(args.seed)
    show_progress = not args.no_tqdm

    run_dir = args.run_dir
    cfg = load_train_config(run_dir)
    dataset = build_dataset_from_config(cfg)
    test_subset = get_test_subset(dataset, cfg)

    age_bins_years = parse_float_list(args.age_bins)
    records = build_event_driven_records(
        subset=test_subset,
        age_bins_years=age_bins_years,
        seed=args.seed,
        show_progress=show_progress,
    )

    device = torch.device(args.device)
    model, head, criterion = build_model_head_criterion(cfg, dataset, device)
    load_checkpoint_into(run_dir, model, head, criterion, device)

    rec_ds = EvalRecordDataset(test_subset, records)
    # NOTE(review): num_workers is passed explicitly below as well; this
    # assumes dl_kwargs does not itself contain "num_workers" — confirm.
    dl_kwargs = make_inference_dataloader_kwargs(device, args.num_workers)
    loader = DataLoader(
        rec_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=eval_collate_fn,
        **dl_kwargs,
    )

    tau = float(args.tau_short)
    scores = predict_cifs(
        model,
        head,
        criterion,
        loader,
        [tau],
        device=device,
        show_progress=show_progress,
        progress_desc="Inference (next-event)",
    )
    # A 3-D result carries a trailing horizon axis; only one horizon was
    # requested, so keep its single slice.
    if scores.ndim == 3:
        scores = scores[:, :, 0]

    # Encode "no observed next event" as -1 so eligibility is `y_next >= 0`.
    y_next = np.array(
        [(-1 if r.next_event_cause is None else int(r.next_event_cause))
         for r in records],
        dtype=np.int64,
    )

    # Strict protocol: evaluate independently per age bin (no mixing).
    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
    age_bins_days = age_bins_years * DAYS_PER_YEAR

    # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1}).
    # side="right" gives a[i-1] <= v < a[i], i.e. a t0 exactly on a lower
    # edge belongs to the bin that starts there.  side="left" would instead
    # send it to the previous bin (or to index -1, i.e. no bin at all, on the
    # first edge).  Values outside all bins get -1 or len-1 and match no bin.
    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1

    per_bin_metric_rows: List[dict] = []
    per_bin_auc_parts: List[pd.DataFrame] = []
    for b in range(len(age_bins_years) - 1):
        lo = float(age_bins_years[b])
        hi = float(age_bins_years[b + 1])
        label = _format_age_bin_label(lo, hi)

        m = bin_idx == b
        m_scores = scores[m]
        m_y = y_next[m]
        m_records = [r for r, keep in zip(records, m) if bool(keep)]

        # The placeholder diagnostic frame returned here is not used; the
        # real per-cause table comes from _compute_within_window_auc below.
        m_rows, _ = _compute_next_event_metrics(
            scores=m_scores,
            y_next=m_y,
            tau_short=tau,
            min_pos=int(args.min_pos),
        )
        for row in m_rows:
            per_bin_metric_rows.append({"age_bin": label, **row})

        m_auc = _compute_within_window_auc(
            scores=m_scores,
            records=m_records,
            tau_short=tau,
            min_pos=int(args.min_pos),
        )
        m_auc.insert(0, "age_bin", label)
        m_auc.insert(1, "tau_short_years", float(tau))
        per_bin_auc_parts.append(m_auc)

    out_metrics_bins = os.path.join(
        run_dir, "next_event_metrics_by_age_bin.csv")
    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)

    out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
    if per_bin_auc_parts:
        pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
            out_auc_bins, index=False)
    else:
        # Fewer than two bin edges: still write a header-only file.
        pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
                              "n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)

    print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
    print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
    print(f"Wrote {out_metrics_bins}")
    print(f"Wrote {out_auc_bins}")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()