Enhance next-event evaluation with age-bin metrics and diagnostic AUC outputs
@@ -1,18 +1,25 @@
-"""Horizon-capture evaluation.
+"""Horizon-capture evaluation (event-driven, age-stratified).
 
-DISCLAIMERS (important):
-- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$.
-  Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW.
-  Use it for model comparison and diagnostics, not strict statistical interpretation.
-
-- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment).
-  Use them to detect probability-mass compression / numerical stability issues; do not claim
-  calibrated absolute risk.
+This script implements the protocol described in 评估方案.md (the evaluation plan):
+- Age-stratified evaluation: metrics are computed independently within each age bin (no mixing).
+- Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and
+  there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin
+  disease events with t0 >= DOA.
+- No follow-up completeness filtering (no t0+tau <= t_end constraint).
+
+Primary outputs per age bin:
+- Top-K Event Capture@tau (event-count based)
+- Workload–Yield curves (Top-p% people by a person-level horizon score)
+
+Secondary (diagnostic-only) outputs per age bin:
+- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment)
 """
 
 import argparse
+import math
 import os
-from typing import Dict, List, Sequence
+from typing import Dict, List, Sequence, Tuple
 
 import numpy as np
 import pandas as pd
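A note on the inclusion rule in the new docstring: it is terse, so here is a minimal sketch of the intended record construction. Everything in it (the function name and the field names `doa_days` / `event_days`) is illustrative — the actual record loader is not part of this diff, and the handling of the corner case where all in-bin events precede DOA is my reading of the rule, not confirmed by the commit.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_record(doa_days: float, event_days: np.ndarray,
                bin_lo_days: float, bin_hi_days: float):
    """Return a sampled baseline t0 for one (person, age_bin), or None.

    Mirrors the docstring rule: include iff DOA <= bin upper bound and the
    person has at least one disease event inside [bin_lo, bin_hi); t0 is
    drawn uniformly from in-bin events with t0 >= DOA.
    """
    if doa_days > bin_hi_days:
        return None
    in_bin = (event_days >= bin_lo_days) & (event_days < bin_hi_days)
    candidates = event_days[in_bin & (event_days >= doa_days)]
    if candidates.size == 0:  # no in-bin event at or after DOA -> no record
        return None
    return float(rng.choice(candidates))
```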
@@ -40,6 +47,7 @@ from utils import (
     roc_auc_ovr,
     seed_everything,
     topk_indices,
+    DAYS_PER_YEAR,
 )
 
 
@@ -58,8 +66,8 @@ def parse_args() -> argparse.Namespace:
         "--age_bins",
         type=str,
         nargs="+",
-        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
-        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
+        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
+        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
     )
 
     p.add_argument(
@@ -75,7 +83,14 @@ def parse_args() -> argparse.Namespace:
         "--topk_list",
         type=int,
         nargs="+",
-        default=[1, 5, 10, 20, 50],
+        default=[5, 10, 20, 50],
+    )
+    p.add_argument(
+        "--workload_fracs",
+        type=float,
+        nargs="+",
+        default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5],
+        help="Fractions for workload–yield curves (Top-p%% people).",
     )
     p.add_argument(
         "--no_tqdm",
@@ -85,6 +100,33 @@ def parse_args() -> argparse.Namespace:
     return p.parse_args()
 
 
+def _format_age_bin_label(lo: float, hi: float) -> str:
+    if np.isinf(hi):
+        return f"[{lo}, inf)"
+    return f"[{lo}, {hi})"
+
+
+def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray:
+    age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64)
+    age_bins_days = age_bins_years * DAYS_PER_YEAR
+    return np.searchsorted(age_bins_days, t0_days, side="left") - 1
+
+
+def _event_counts_within_tau(
+    n_records: int,
+    event_record_idx: np.ndarray,
+    event_dt_years: np.ndarray,
+    tau_years: float,
+) -> np.ndarray:
+    """Count events within (t0, t0+tau] per record (event-count, not unique causes)."""
+    if event_record_idx.size == 0:
+        return np.zeros((n_records,), dtype=np.int64)
+    m = event_dt_years <= float(tau_years)
+    if not np.any(m):
+        return np.zeros((n_records,), dtype=np.int64)
+    return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64)
+
+
 def build_labels_within_tau_flat(
     n_records: int,
     n_causes: int,
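The two helpers added above are easy to sanity-check in isolation. A minimal standalone sketch with made-up numbers follows; note the `side="left"` edge case: a `t0` exactly on a boundary falls into the lower bin, which is harmless here because, by construction, `t0` is sampled strictly inside a bin. The `DAYS_PER_YEAR` value is an assumption — the script imports it from `utils`.

```python
import numpy as np

DAYS_PER_YEAR = 365.25  # assumed; the script imports this from utils

bins_years = np.array([40.0, 45.0, 50.0, np.inf])
t0_days = np.array([41.0, 45.0, 47.3, 80.0]) * DAYS_PER_YEAR

# side="left": a t0 exactly at 45 lands in bin 0, i.e. effectively (lo, hi].
idx = np.searchsorted(bins_years * DAYS_PER_YEAR, t0_days, side="left") - 1
print(idx)  # [0 0 1 2]

# Per-record event counts within tau, as in _event_counts_within_tau:
event_record_idx = np.array([0, 0, 2, 3, 3, 3])
event_dt_years = np.array([0.5, 2.0, 9.0, 0.1, 0.2, 12.0])
m = event_dt_years <= 5.0
print(np.bincount(event_record_idx[m], minlength=4))  # [2 0 0 2]
```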
@@ -148,8 +190,8 @@ def main() -> None:
     )
 
     # Print disclaimers every run (requested)
-    print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
-    print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
+    print("PRIMARY METRICS: event-count Capture@K and Workload–Yield, computed independently per age bin.")
+    print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).")
 
     scores = predict_cifs(
         model,
@@ -171,131 +213,235 @@ def main() -> None:
         raise ValueError("Record count mismatch")
 
     # Pre-flatten all future events once to avoid repeated per-record scans.
+    # NOTE: these are event-level arrays (not unique causes), suitable for event-count Capture@K.
     evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)
 
-    per_tau_rows: List[Dict[str, object]] = []
-    per_cause_rows: List[Dict[str, object]] = []
+    # Assign each record to an age bin (based on t0; by construction t0 is within the bin).
+    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
+    bin_idx = _assign_age_bin_idx(t0_days, age_bins_years)
+    age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64)
+
+    capture_rows: List[Dict[str, object]] = []
     workload_rows: List[Dict[str, object]] = []
 
-    horizon_iter = enumerate(horizons)
+    # Diagnostics (optional): approximate event-driven AUC/Brier computed per bin.
+    diag_rows: List[Dict[str, object]] = []
+    diag_per_cause_parts: List[pd.DataFrame] = []
+
+    bins_iter = range(len(age_bins_years_arr) - 1)
     if show_progress and tqdm is not None:
-        horizon_iter = tqdm(horizon_iter, total=len(
-            horizons), desc="Metrics by horizon")
-
-    for h_idx, tau in horizon_iter:
-        s_tau = scores[:, :, h_idx]
-        y_tau = build_labels_within_tau_flat(
-            N, K, evt_rec_idx, evt_cause, evt_dt, tau)
-
-        # Per-cause counts + Brier (vectorized)
-        n_pos = y_tau.sum(axis=0).astype(np.int64)
-        n_neg = (int(N) - n_pos).astype(np.int64)
-
-        # Brier per cause: mean_i (y - s)^2
-        brier_per_cause = np.mean(
-            (y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
-        brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
-        brier_weighted = float(np.sum(
-            brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
-
-        # AUC: compute only for causes with enough positives and at least 1 negative
-        auc = np.full((K,), np.nan, dtype=np.float64)
-        min_pos = int(args.min_pos)
-        candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
-        for k in candidates:
-            auc[k] = roc_auc_ovr(y_tau[:, k].astype(
-                np.int32), s_tau[:, k].astype(np.float64))
-
-        finite_auc = auc[np.isfinite(auc)]
-        auc_macro = float(np.mean(finite_auc)
-                          ) if finite_auc.size > 0 else float("nan")
-        w_mask = np.isfinite(auc)
-        auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
-            n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
-        n_valid_auc = int(np.isfinite(auc).sum())
-
-        # Append per-cause rows (vectorized via DataFrame to avoid Python loops)
-        per_cause_rows.append(
-            pd.DataFrame(
-                {
-                    "tau_years": float(tau),
-                    "cause_id": np.arange(K, dtype=np.int64),
-                    "n_pos": n_pos,
-                    "n_neg": n_neg,
-                    "auc": auc,
-                    "brier": brier_per_cause,
-                }
-            )
-        )
-
-        # Business metrics for each topK
-        denom_true_pairs = int(y_tau.sum())
-        for topk in args.topk_list:
-            topk = int(topk)
-            idx = topk_indices(s_tau, topk)
-            captured = np.take_along_axis(y_tau, idx, axis=1)
-            hits = captured.sum(axis=1).astype(np.float64)
-            true_cnt = y_tau.sum(axis=1).astype(np.float64)
-
-            precision_like = hits / float(min(topk, K))
-            mean_precision = float(np.mean(precision_like)
-                                   ) if N > 0 else float("nan")
-
-            mask_has_true = true_cnt > 0
-            recall_like = np.full((N,), np.nan, dtype=np.float64)
-            recall_like[mask_has_true] = hits[mask_has_true] / \
-                true_cnt[mask_has_true]
-            mean_recall = float(np.nanmean(recall_like)) if np.any(
-                mask_has_true) else float("nan")
-            median_recall = float(np.nanmedian(recall_like)) if np.any(
-                mask_has_true) else float("nan")
-
-            numer_captured_pairs = int(captured.sum())
-            pop_capture_rate = float(
-                numer_captured_pairs / denom_true_pairs) if denom_true_pairs > 0 else float("nan")
-
-            workload_rows.append(
-                {
-                    "tau_years": float(tau),
-                    "topk": int(topk),
-                    "population_capture_rate": pop_capture_rate,
-                    "mean_precision_like": mean_precision,
-                    "mean_recall_like": mean_recall,
-                    "median_recall_like": median_recall,
-                    "denom_true_pairs": denom_true_pairs,
-                    "numer_captured_pairs": numer_captured_pairs,
-                }
-            )
-
-        per_tau_rows.append(
-            {
-                "tau_years": float(tau),
-                "n_records": int(N),
-                "n_causes": int(K),
-                "auc_macro": auc_macro,
-                "auc_weighted_by_npos": auc_weighted,
-                "n_causes_valid_auc": int(n_valid_auc),
-                "brier_macro": brier_macro,
-                "brier_weighted_by_npos": brier_weighted,
-                "total_true_pairs": denom_true_pairs,
-            }
-        )
-
-    out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
-    out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
+        bins_iter = tqdm(bins_iter, total=len(
+            age_bins_years_arr) - 1, desc="Age bins")
+
+    for b in bins_iter:
+        lo = float(age_bins_years_arr[b])
+        hi = float(age_bins_years_arr[b + 1])
+        age_label = _format_age_bin_label(lo, hi)
+
+        m_rec = bin_idx == b
+        n_bin = int(m_rec.sum())
+        if n_bin == 0:
+            continue
+
+        rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32)
+
+        # Filter events to this bin's records once.
+        m_evt_bin = m_rec[evt_rec_idx] if evt_rec_idx.size > 0 else np.zeros(
+            (0,), dtype=bool)
+        evt_rec_idx_b = evt_rec_idx[m_evt_bin]
+        evt_cause_b = evt_cause[m_evt_bin]
+        evt_dt_b = evt_dt[m_evt_bin]
+
+        horizon_iter = enumerate(horizons)
+        if show_progress and tqdm is not None:
+            horizon_iter = tqdm(horizon_iter, total=len(
+                horizons), desc=f"Horizons {age_label}")
+
+        # Precompute a local index mapping for diagnostics label building.
+        local_map = np.full((N,), -1, dtype=np.int32)
+        local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32)
+
+        for h_idx, tau in horizon_iter:
+            s_tau_all = scores[:, :, h_idx]
+            s_tau = s_tau_all[m_rec]
+
+            # -------------------------
+            # Primary metric: Top-K Event Capture@tau (event-count based)
+            # -------------------------
+            denom_events = int(np.sum(evt_dt_b <= float(tau))
+                               ) if evt_dt_b.size > 0 else 0
+            if denom_events == 0:
+                for topk in args.topk_list:
+                    capture_rows.append(
+                        {
+                            "age_bin": age_label,
+                            "tau_years": float(tau),
+                            "topk": int(topk),
+                            "capture_at_k": float("nan"),
+                            "denom_events": int(0),
+                            "numer_events": int(0),
+                            "n_records": int(n_bin),
+                            "n_causes": int(K),
+                        }
+                    )
+            else:
+                m_evt_tau = evt_dt_b <= float(tau)
+                evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau]
+                evt_cause_tau = evt_cause_b[m_evt_tau]
+
+                # For each K, compute whether each event's cause is in that record's Top-K list.
+                for topk in args.topk_list:
+                    topk = int(topk)
+                    idx = topk_indices(s_tau_all, topk)  # shape (N, topk)
+                    idx_for_events = idx[evt_rec_idx_tau]
+                    hits = (idx_for_events ==
+                            evt_cause_tau[:, None]).any(axis=1)
+                    numer_events = int(hits.sum())
+                    capture = float(numer_events / denom_events)
+                    capture_rows.append(
+                        {
+                            "age_bin": age_label,
+                            "tau_years": float(tau),
+                            "topk": int(topk),
+                            "capture_at_k": capture,
+                            "denom_events": int(denom_events),
+                            "numer_events": int(numer_events),
+                            "n_records": int(n_bin),
+                            "n_causes": int(K),
+                        }
+                    )
+
+            # -------------------------
+            # Primary metric: Workload–Yield (Top-p% people)
+            # -------------------------
+            # Person-level score: max_k CIF_k(tau). This is used only for workload–yield ranking.
+            person_score = s_tau.max(axis=1) if K > 0 else np.zeros(
+                (n_bin,), dtype=np.float64)
+            order = np.argsort(-person_score, kind="mergesort")
+
+            counts_per_record = _event_counts_within_tau(
+                n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau)
+            total_events = int(counts_per_record.sum())
+            overall_events_per_person = (
+                total_events / float(n_bin)) if n_bin > 0 else float("nan")
+
+            for frac in args.workload_fracs:
+                frac = float(frac)
+                if frac <= 0.0:
+                    continue
+                n_sel = int(math.ceil(frac * n_bin))
+                n_sel = min(max(n_sel, 1), n_bin)
+                sel_local = order[:n_sel]
+                events_captured = int(counts_per_record[sel_local].sum())
+                capture_rate = float(
+                    events_captured / total_events) if total_events > 0 else float("nan")
+
+                selected_events_per_person = (
+                    events_captured / float(n_sel)) if n_sel > 0 else float("nan")
+                lift = (selected_events_per_person /
+                        overall_events_per_person) if overall_events_per_person > 0 else float("nan")
+
+                workload_rows.append(
+                    {
+                        "age_bin": age_label,
+                        "tau_years": float(tau),
+                        "frac_selected": float(frac),
+                        "n_selected": int(n_sel),
+                        "n_records": int(n_bin),
+                        "total_events": int(total_events),
+                        "events_captured": int(events_captured),
+                        "capture_rate": capture_rate,
+                        "lift_events_per_person": float(lift),
+                        "person_score_def": "max_k_CIF_k(tau)",
+                    }
+                )
+
+            # -------------------------
+            # Diagnostics (optional): approximate event-driven AUC/Brier
+            # -------------------------
+            # Convert event-level data to binary labels y[i,k]=1 iff >=1 event of cause k within tau.
+            y_tau_bin = np.zeros((n_bin, K), dtype=np.int8)
+            if evt_dt_b.size > 0:
+                m_evt_tau = evt_dt_b <= float(tau)
+                if np.any(m_evt_tau):
+                    rec_local = local_map[evt_rec_idx_b[m_evt_tau]]
+                    valid = rec_local >= 0
+                    y_tau_bin[rec_local[valid],
+                              evt_cause_b[m_evt_tau][valid]] = 1
+
+            n_pos = y_tau_bin.sum(axis=0).astype(np.int64)
+            n_neg = (int(n_bin) - n_pos).astype(np.int64)
+
+            brier_per_cause = np.mean(
+                (y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0
+            )
+            brier_macro = float(np.mean(brier_per_cause)
+                                ) if K > 0 else float("nan")
+            brier_weighted = float(np.sum(
+                brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
+
+            auc = np.full((K,), np.nan, dtype=np.float64)
+            min_pos = int(args.min_pos)
+            candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
+            for k in candidates:
+                auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype(
+                    np.int32), s_tau[:, k].astype(np.float64))
+
+            finite_auc = auc[np.isfinite(auc)]
+            auc_macro = float(np.mean(finite_auc)
+                              ) if finite_auc.size > 0 else float("nan")
+            w_mask = np.isfinite(auc)
+            auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
+                n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
+            n_valid_auc = int(np.isfinite(auc).sum())
+
+            diag_rows.append(
+                {
+                    "age_bin": age_label,
+                    "tau_years": float(tau),
+                    "n_records": int(n_bin),
+                    "n_causes": int(K),
+                    "auc_macro": auc_macro,
+                    "auc_weighted_by_npos": auc_weighted,
+                    "n_causes_valid_auc": int(n_valid_auc),
+                    "brier_macro": brier_macro,
+                    "brier_weighted_by_npos": brier_weighted,
+                }
+            )
+
+            diag_per_cause_parts.append(
+                pd.DataFrame(
+                    {
+                        "age_bin": age_label,
+                        "tau_years": float(tau),
+                        "cause_id": np.arange(K, dtype=np.int64),
+                        "n_pos": n_pos,
+                        "n_neg": n_neg,
+                        "auc": auc,
+                        "brier": brier_per_cause,
+                    }
+                )
+            )
+
+    out_capture = os.path.join(run_dir, "horizon_capture.csv")
     out_wy = os.path.join(run_dir, "workload_yield.csv")
+    out_diag = os.path.join(run_dir, "horizon_metrics.csv")
+    out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv")
 
-    pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
-    if per_cause_rows:
-        pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False)
-    else:
-        pd.DataFrame(columns=["tau_years", "cause_id", "n_pos",
-                              "n_neg", "auc", "brier"]).to_csv(out_pc, index=False)
+    pd.DataFrame(capture_rows).to_csv(out_capture, index=False)
     pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
+    pd.DataFrame(diag_rows).to_csv(out_diag, index=False)
+    if diag_per_cause_parts:
+        pd.concat(diag_per_cause_parts, ignore_index=True).to_csv(
+            out_diag_pc, index=False)
+    else:
+        pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos",
+                              "n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False)
 
-    print(f"Wrote {out_metrics}")
-    print(f"Wrote {out_pc}")
+    print(f"Wrote {out_capture}")
     print(f"Wrote {out_wy}")
+    print(f"Wrote {out_diag} (diagnostic-only)")
+    print(f"Wrote {out_diag_pc} (diagnostic-only)")
 
 
 if __name__ == "__main__":
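To make the two primary metrics concrete, here is a self-contained sketch with synthetic data. `np.argsort` stands in for the project's `topk_indices` helper, all shapes and values are made up, and the broadcasting trick mirrors the Capture@K code above. The hunks after this point belong to the second changed file, the next-event evaluation script.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bin, K = 6, 4                      # records in one age bin, number of causes
s_tau = rng.random((n_bin, K))       # CIF_k(tau) per record (synthetic)

# Flattened events for this bin: (record index, cause, years after t0)
evt_rec = np.array([0, 0, 2, 5])
evt_cause = np.array([1, 3, 0, 2])
evt_dt = np.array([0.4, 1.2, 3.0, 0.9])
tau, topk = 2.0, 2

# --- Top-K Event Capture@tau (event-count based) ---
m = evt_dt <= tau                                   # events inside the horizon
idx = np.argsort(-s_tau, axis=1)[:, :topk]          # per-record Top-K causes
hits = (idx[evt_rec[m]] == evt_cause[m][:, None]).any(axis=1)
capture_at_k = hits.sum() / m.sum()                 # numer_events / denom_events

# --- Workload–Yield at a fraction p ---
person_score = s_tau.max(axis=1)                    # max_k CIF_k(tau)
counts = np.bincount(evt_rec[m], minlength=n_bin)   # in-horizon events per record
order = np.argsort(-person_score, kind="mergesort")
n_sel = max(1, int(np.ceil(0.5 * n_bin)))           # p = 50% for illustration
captured = counts[order[:n_sel]].sum()
lift = (captured / n_sel) / (counts.sum() / n_bin)  # events/person: selected vs overall
print(capture_at_k, captured / counts.sum(), lift)
```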
@@ -41,8 +41,8 @@ def parse_args() -> argparse.Namespace:
         "--age_bins",
         type=str,
         nargs="+",
-        default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
-        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
+        default=["40", "45", "50", "55", "60", "65", "70", "inf"],
+        help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
     )
 
     p.add_argument(
@@ -80,9 +80,14 @@ def _compute_next_event_metrics(
     tau_short: float,
     min_pos: int,
 ) -> tuple[list[dict], pd.DataFrame]:
-    """Compute next-event metrics on a given subset.
+    """Compute next-event *primary* metrics on a given subset.
 
-    Definitions are unchanged from the original script.
+    Implements 评估方案.md (the evaluation plan), next-event section:
+    - score_k = CIF_k(tau_short)
+    - Hit@K / MRR are computed on records with an observed next-event.
+
+    Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table
+    based on whether the cause occurs within (t0, t0+tau_short] (display-only).
     """
     n_records_total = int(y_next.size)
     eligible = y_next >= 0
@@ -99,49 +104,40 @@ def _compute_next_event_metrics(
         {"metric": "tau_short_years", "value": float(tau_short)})
 
     K = int(scores.shape[1])
+    # Diagnostic: build per-cause AUC using within-window labels.
+    # This is NOT a primary metric (no IPCW / censoring adjustment).
+    diag_df = pd.DataFrame(
+        {
+            "cause_id": np.arange(K, dtype=np.int64),
+            "n_pos": np.zeros((K,), dtype=np.int64),
+            "n_neg": np.zeros((K,), dtype=np.int64),
+            "auc": np.full((K,), np.nan, dtype=np.float64),
+            "included": np.zeros((K,), dtype=bool),
+        }
+    )
     if n_records_total == 0:
-        per_cause_df = pd.DataFrame(
-            {
-                "cause_id": np.arange(K, dtype=np.int64),
-                "n_pos": np.zeros((K,), dtype=np.int64),
-                "n_neg": np.zeros((K,), dtype=np.int64),
-                "auc": np.full((K,), np.nan, dtype=np.float64),
-                "included": np.zeros((K,), dtype=bool),
-            }
-        )
-        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
+        metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
         metrics_rows.append({"metric": "mrr", "value": float("nan")})
         for k in [1, 3, 5, 10, 20]:
             metrics_rows.append(
                 {"metric": f"hitrate_at_{k}", "value": float("nan")})
-        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
-        return metrics_rows, per_cause_df
+        return metrics_rows, diag_df
 
     # If no eligible, keep coverage but leave accuracy-like metrics as NaN.
     if n_eligible == 0:
-        per_cause_df = pd.DataFrame(
-            {
-                "cause_id": np.arange(K, dtype=np.int64),
-                "n_pos": np.zeros((K,), dtype=np.int64),
-                "n_neg": np.full((K,), n_records_total, dtype=np.int64),
-                "auc": np.full((K,), np.nan, dtype=np.float64),
-                "included": np.zeros((K,), dtype=bool),
-            }
-        )
-        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
+        metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
         metrics_rows.append({"metric": "mrr", "value": float("nan")})
         for k in [1, 3, 5, 10, 20]:
             metrics_rows.append(
                 {"metric": f"hitrate_at_{k}", "value": float("nan")})
-        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
-        return metrics_rows, per_cause_df
+        return metrics_rows, diag_df
 
     scores_e = scores[eligible]
     y_e = y_next[eligible]
 
     pred = scores_e.argmax(axis=1)
     acc = float((pred == y_e).mean())
-    metrics_rows.append({"metric": "top1_accuracy", "value": acc})
+    metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
 
     # MRR
     order = np.argsort(-scores_e, axis=1, kind="mergesort")
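For orientation, Hit@K and MRR from a descending `argsort` ranking can be sketched as below. The scores and labels are synthetic, and the real script's bookkeeping around this hunk (the lines that consume `order`) is not shown in the diff, so treat this as an assumed equivalent rather than the script's exact code.

```python
import numpy as np

scores_e = np.array([[0.1, 0.7, 0.2],      # record 0: true next cause = 1
                     [0.5, 0.3, 0.2]])     # record 1: true next cause = 2
y_e = np.array([1, 2])

order = np.argsort(-scores_e, axis=1, kind="mergesort")  # causes by descending score
rank = np.argmax(order == y_e[:, None], axis=1) + 1      # 1-based rank of true cause
mrr = float((1.0 / rank).mean())                         # (1/1 + 1/3) / 2 = 0.666...
hit_at_2 = float((rank <= 2).mean())                     # record 0 only -> 0.5
print(mrr, hit_at_2)
```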
@@ -158,17 +154,53 @@ def _compute_next_event_metrics(
         metrics_rows.append({"metric": f"hitrate_at_{k}",
                              "value": float(hit.mean())})
 
-    # Macro OvR AUC per cause (optional)
-    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
-    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
+    # Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here.
+    _ = min_pos
+    return metrics_rows, diag_df
+
+
+def _compute_within_window_auc(
+    *,
+    scores: np.ndarray,
+    records: list,
+    tau_short: float,
+    min_pos: int,
+) -> pd.DataFrame:
+    """Diagnostic-only per-cause AUC.
+
+    Label definition (event-driven, approximate; no IPCW):
+    y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
+    """
+    n_records = int(len(records))
+    if n_records == 0:
+        return pd.DataFrame(
+            columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
+        )
+
+    K = int(scores.shape[1])
+    y = np.zeros((n_records, K), dtype=np.int8)
+    tau = float(tau_short)
+
+    # Build labels from future events.
+    for i, r in enumerate(records):
+        if r.future_causes.size == 0:
+            continue
+        m = r.future_dt_years <= tau
+        if not np.any(m):
+            continue
+        y[i, r.future_causes[m]] = 1
+
+    n_pos = y.sum(axis=0).astype(np.int64)
+    n_neg = (int(n_records) - n_pos).astype(np.int64)
 
     auc = np.full((K,), np.nan, dtype=np.float64)
     candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
     for k in candidates:
-        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
+        auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
+                             scores[:, k].astype(np.float64))
 
     included = (n_pos >= int(min_pos)) & (n_neg > 0)
-    per_cause_df = pd.DataFrame(
+    return pd.DataFrame(
         {
             "cause_id": np.arange(K, dtype=np.int64),
             "n_pos": n_pos,
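`roc_auc_ovr` is imported from the project's utils and its body is not part of this diff. The sketch below is an assumption about its semantics — a plain Mann–Whitney (rank) AUC with ties counted as 1/2, written pairwise for clarity rather than efficiency (O(P·N); a midrank implementation would be preferred for large inputs).

```python
import numpy as np

def mann_whitney_auc(y: np.ndarray, s: np.ndarray) -> float:
    """AUC as P(score_pos > score_neg) + 0.5 * P(tie); assumed roc_auc_ovr semantics."""
    pos = s[y == 1]
    neg = s[y == 0]
    if pos.size == 0 or neg.size == 0:
        return float("nan")
    greater = (pos[:, None] > neg[None, :]).sum()   # correctly ordered pairs
    ties = (pos[:, None] == neg[None, :]).sum()     # tied pairs count half
    return float((greater + 0.5 * ties) / (pos.size * neg.size))

# Quick check on a toy example:
y = np.array([0, 0, 1, 1], dtype=np.int32)
s = np.array([0.1, 0.4, 0.35, 0.8])
print(mann_whitney_auc(y, s))  # 0.75: one of four pos-neg pairs is mis-ordered
```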
@@ -178,15 +210,6 @@ def _compute_next_event_metrics(
         }
     )
 
-    aucs = auc[np.isfinite(auc)]
-    if aucs.size > 0:
-        metrics_rows.append(
-            {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
-    else:
-        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
-
-    return metrics_rows, per_cause_df
-
 
 def main() -> None:
     args = parse_args()
@@ -246,19 +269,7 @@ def main() -> None:
     )
 
     # Overall (preserve existing output files/shape)
-    metrics_rows, per_cause_df = _compute_next_event_metrics(
-        scores=scores,
-        y_next=y_next,
-        tau_short=tau,
-        min_pos=int(args.min_pos),
-    )
-
-    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
-    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
-    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
-    per_cause_df.to_csv(out_pc, index=False)
-
-    # By age bin (new outputs)
+    # Strict protocol: evaluate independently per age bin (no mixing).
     age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
     age_bins_days = age_bins_years * DAYS_PER_YEAR
     # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
@@ -266,7 +277,7 @@ def main() -> None:
     bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1
 
     per_bin_metric_rows: List[dict] = []
-    per_bin_cause_parts: List[pd.DataFrame] = []
+    per_bin_auc_parts: List[pd.DataFrame] = []
     for b in range(len(age_bins_years) - 1):
         lo = float(age_bins_years[b])
         hi = float(age_bins_years[b + 1])
@@ -274,6 +285,7 @@ def main() -> None:
         m = bin_idx == b
         m_scores = scores[m]
         m_y = y_next[m]
+        m_records = [r for r, keep in zip(records, m) if bool(keep)]
         m_rows, m_pc = _compute_next_event_metrics(
             scores=m_scores,
             y_next=m_y,
@@ -282,25 +294,32 @@ def main() -> None:
         )
         for row in m_rows:
             per_bin_metric_rows.append({"age_bin": label, **row})
-        m_pc = m_pc.copy()
-        m_pc.insert(0, "age_bin", label)
-        per_bin_cause_parts.append(m_pc)
+        m_auc = _compute_within_window_auc(
+            scores=m_scores,
+            records=m_records,
+            tau_short=tau,
+            min_pos=int(args.min_pos),
+        )
+        m_auc.insert(0, "age_bin", label)
+        m_auc.insert(1, "tau_short_years", float(tau))
+        per_bin_auc_parts.append(m_auc)
 
     out_metrics_bins = os.path.join(
         run_dir, "next_event_metrics_by_age_bin.csv")
     pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
-    out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
-    if per_bin_cause_parts:
-        pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
-            out_pc_bins, index=False)
-    else:
-        pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
-                              "auc", "included"]).to_csv(out_pc_bins, index=False)
-
-    print(f"Wrote {out_metrics}")
-    print(f"Wrote {out_pc}")
+    out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
+    if per_bin_auc_parts:
+        pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
+            out_auc_bins, index=False)
+    else:
+        pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
+                              "n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
+
+    print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
+    print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
     print(f"Wrote {out_metrics_bins}")
-    print(f"Wrote {out_pc_bins}")
+    print(f"Wrote {out_auc_bins}")
 
 
 if __name__ == "__main__":
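Finally, a quick way to inspect the new per-bin outputs written above; `run_dir` is a placeholder, and the column names come from the diff itself:

```python
import pandas as pd

run_dir = "runs/example"  # placeholder; set to the actual run directory

# Primary next-event metrics: one row per (age_bin, metric).
metrics = pd.read_csv(f"{run_dir}/next_event_metrics_by_age_bin.csv")
print(metrics[metrics["metric"].isin(["hitrate_at_1", "mrr"])])

# Diagnostic per-cause AUC; keep only causes that met the min_pos threshold.
auc = pd.read_csv(f"{run_dir}/next_event_auc_by_age_bin.csv")
print(auc[auc["included"]].sort_values("auc", ascending=False).head(10))
```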