Enhance next-event evaluation with age-bin metrics and diagnostic AUC outputs

2026-01-17 15:31:12 +08:00
parent 197842b1a6
commit fcd948818c
2 changed files with 348 additions and 183 deletions

View File

@@ -1,18 +1,25 @@
"""Horizon-capture evaluation. """Horizon-capture evaluation (event-driven, age-stratified).
DISCLAIMERS (important): This script implements the protocol described in 评估方案.md:
- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$.
Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW.
Use it for model comparison and diagnostics, not strict statistical interpretation.
- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment). - Age-stratified evaluation: metrics are computed independently within each age bin (no mixing).
Use them to detect probability-mass compression / numerical stability issues; do not claim - Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and
calibrated absolute risk. there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin
disease events with t0 >= DOA.
- No follow-up completeness filtering (no t0+tau <= t_end constraint).
Primary outputs per age bin:
- Top-K Event Capture@tau (event-count based)
- WorkloadYield curves (Top-p% people by a person-level horizon score)
Secondary (diagnostic-only) outputs per age bin:
- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment)
""" """
import argparse import argparse
import math
import os import os
from typing import Dict, List, Sequence from typing import Dict, List, Sequence, Tuple
import numpy as np import numpy as np
import pandas as pd import pandas as pd
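
The record-inclusion rule in the new docstring can be made concrete with a minimal sketch (the field names and the helper itself are hypothetical; the real record construction lives elsewhere in the repo):

import numpy as np

rng = np.random.default_rng(0)

def sample_record_t0(doa_days, event_days, bin_lo_days, bin_hi_days):
    """Return a sampled baseline t0 for one (person, age_bin), or None if excluded."""
    if doa_days > bin_hi_days:  # require DOA <= bin upper bound
        return None
    eligible = [t for t in event_days
                if bin_lo_days <= t < bin_hi_days and t >= doa_days]
    if not eligible:            # require >=1 in-bin disease event at or after DOA
        return None
    return float(rng.choice(eligible))  # t0 sampled randomly from eligible events
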
@@ -40,6 +47,7 @@ from utils import (
     roc_auc_ovr,
     seed_everything,
     topk_indices,
+    DAYS_PER_YEAR,
 )
@@ -58,8 +66,8 @@ def parse_args() -> argparse.Namespace:
"--age_bins", "--age_bins",
type=str, type=str,
nargs="+", nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"], default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)", help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
) )
p.add_argument( p.add_argument(
@@ -75,7 +83,14 @@ def parse_args() -> argparse.Namespace:
"--topk_list", "--topk_list",
type=int, type=int,
nargs="+", nargs="+",
default=[1, 5, 10, 20, 50], default=[5, 10, 20, 50],
)
p.add_argument(
"--workload_fracs",
type=float,
nargs="+",
default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5],
help="Fractions for workloadyield curves (Top-p%% people).",
) )
p.add_argument( p.add_argument(
"--no_tqdm", "--no_tqdm",
@@ -85,6 +100,33 @@ def parse_args() -> argparse.Namespace:
     return p.parse_args()
 
 
+def _format_age_bin_label(lo: float, hi: float) -> str:
+    if np.isinf(hi):
+        return f"[{lo}, inf)"
+    return f"[{lo}, {hi})"
+
+
+def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray:
+    age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64)
+    age_bins_days = age_bins_years * DAYS_PER_YEAR
+    return np.searchsorted(age_bins_days, t0_days, side="left") - 1
+
+
+def _event_counts_within_tau(
+    n_records: int,
+    event_record_idx: np.ndarray,
+    event_dt_years: np.ndarray,
+    tau_years: float,
+) -> np.ndarray:
+    """Count events within (t0, t0+tau] per record (event-count, not unique causes)."""
+    if event_record_idx.size == 0:
+        return np.zeros((n_records,), dtype=np.int64)
+    m = event_dt_years <= float(tau_years)
+    if not np.any(m):
+        return np.zeros((n_records,), dtype=np.int64)
+    return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64)
+
+
 def build_labels_within_tau_flat(
     n_records: int,
     n_causes: int,
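
A quick toy check of the new helpers (assuming utils.DAYS_PER_YEAR is the usual 365.25; the inputs are made up):

import numpy as np

DAYS_PER_YEAR = 365.25  # assumed value of utils.DAYS_PER_YEAR

bins_years = np.array([40.0, 45.0, 50.0, np.inf])
t0_days = np.array([41.0, 47.5, 80.0]) * DAYS_PER_YEAR
# _assign_age_bin_idx: searchsorted on day-scaled boundaries, then -1
bin_idx = np.searchsorted(bins_years * DAYS_PER_YEAR, t0_days, side="left") - 1
print(bin_idx)  # [0 1 2] -> bins [40,45), [45,50), [50,inf)

# _event_counts_within_tau: count events with dt <= tau per record
evt_rec = np.array([0, 0, 2, 2, 2])
evt_dt = np.array([0.5, 3.0, 1.0, 2.0, 9.0])
print(np.bincount(evt_rec[evt_dt <= 2.0], minlength=3))  # [1 0 2]
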
@@ -148,8 +190,8 @@ def main() -> None:
     )
 
     # Print disclaimers every run (requested)
-    print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
-    print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
+    print("PRIMARY METRICS: event-count Capture@K and WorkloadYield, computed independently per age bin.")
+    print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).")
 
     scores = predict_cifs(
         model,
@@ -171,131 +213,235 @@ def main() -> None:
         raise ValueError("Record count mismatch")
 
     # Pre-flatten all future events once to avoid repeated per-record scans.
+    # NOTE: these are event-level arrays (not unique causes), suitable for event-count Capture@K.
     evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)
 
-    per_tau_rows: List[Dict[str, object]] = []
-    per_cause_rows: List[Dict[str, object]] = []
-    workload_rows: List[Dict[str, object]] = []
-
-    horizon_iter = enumerate(horizons)
-    if show_progress and tqdm is not None:
-        horizon_iter = tqdm(horizon_iter, total=len(
-            horizons), desc="Metrics by horizon")
-
-    for h_idx, tau in horizon_iter:
-        s_tau = scores[:, :, h_idx]
-        y_tau = build_labels_within_tau_flat(
-            N, K, evt_rec_idx, evt_cause, evt_dt, tau)
-
-        # Per-cause counts + Brier (vectorized)
-        n_pos = y_tau.sum(axis=0).astype(np.int64)
-        n_neg = (int(N) - n_pos).astype(np.int64)
-
-        # Brier per cause: mean_i (y - s)^2
-        brier_per_cause = np.mean(
-            (y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
-        brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
-        brier_weighted = float(np.sum(
-            brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
-
-        # AUC: compute only for causes with enough positives and at least 1 negative
-        auc = np.full((K,), np.nan, dtype=np.float64)
-        min_pos = int(args.min_pos)
-        candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
-        for k in candidates:
-            auc[k] = roc_auc_ovr(y_tau[:, k].astype(
-                np.int32), s_tau[:, k].astype(np.float64))
-
-        finite_auc = auc[np.isfinite(auc)]
-        auc_macro = float(np.mean(finite_auc)
-                          ) if finite_auc.size > 0 else float("nan")
-        w_mask = np.isfinite(auc)
-        auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
-            n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
-        n_valid_auc = int(np.isfinite(auc).sum())
-
-        # Append per-cause rows (vectorized via DataFrame to avoid Python loops)
-        per_cause_rows.append(
-            pd.DataFrame(
-                {
-                    "tau_years": float(tau),
-                    "cause_id": np.arange(K, dtype=np.int64),
-                    "n_pos": n_pos,
-                    "n_neg": n_neg,
-                    "auc": auc,
-                    "brier": brier_per_cause,
-                }
-            )
-        )
-
-        # Business metrics for each topK
-        denom_true_pairs = int(y_tau.sum())
-        for topk in args.topk_list:
-            topk = int(topk)
-            idx = topk_indices(s_tau, topk)
-            captured = np.take_along_axis(y_tau, idx, axis=1)
-            hits = captured.sum(axis=1).astype(np.float64)
-            true_cnt = y_tau.sum(axis=1).astype(np.float64)
-
-            precision_like = hits / float(min(topk, K))
-            mean_precision = float(np.mean(precision_like)
-                                   ) if N > 0 else float("nan")
-
-            mask_has_true = true_cnt > 0
-            recall_like = np.full((N,), np.nan, dtype=np.float64)
-            recall_like[mask_has_true] = hits[mask_has_true] / \
-                true_cnt[mask_has_true]
-            mean_recall = float(np.nanmean(recall_like)) if np.any(
-                mask_has_true) else float("nan")
-            median_recall = float(np.nanmedian(recall_like)) if np.any(
-                mask_has_true) else float("nan")
-
-            numer_captured_pairs = int(captured.sum())
-            pop_capture_rate = float(
-                numer_captured_pairs / denom_true_pairs) if denom_true_pairs > 0 else float("nan")
-
-            workload_rows.append(
-                {
-                    "tau_years": float(tau),
-                    "topk": int(topk),
-                    "population_capture_rate": pop_capture_rate,
-                    "mean_precision_like": mean_precision,
-                    "mean_recall_like": mean_recall,
-                    "median_recall_like": median_recall,
-                    "denom_true_pairs": denom_true_pairs,
-                    "numer_captured_pairs": numer_captured_pairs,
-                }
-            )
-
-        per_tau_rows.append(
-            {
-                "tau_years": float(tau),
-                "n_records": int(N),
-                "n_causes": int(K),
-                "auc_macro": auc_macro,
-                "auc_weighted_by_npos": auc_weighted,
-                "n_causes_valid_auc": int(n_valid_auc),
-                "brier_macro": brier_macro,
-                "brier_weighted_by_npos": brier_weighted,
-                "total_true_pairs": denom_true_pairs,
-            }
-        )
-
-    out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
-    out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
-    out_wy = os.path.join(run_dir, "workload_yield.csv")
-
-    pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
-    if per_cause_rows:
-        pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False)
-    else:
-        pd.DataFrame(columns=["tau_years", "cause_id", "n_pos",
-                     "n_neg", "auc", "brier"]).to_csv(out_pc, index=False)
-    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
-
-    print(f"Wrote {out_metrics}")
-    print(f"Wrote {out_pc}")
-    print(f"Wrote {out_wy}")
+    # Assign each record to an age bin (based on t0; by construction t0 is within the bin).
+    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
+    bin_idx = _assign_age_bin_idx(t0_days, age_bins_years)
+    age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64)
+
+    capture_rows: List[Dict[str, object]] = []
+    workload_rows: List[Dict[str, object]] = []
+    # Diagnostics (optional): approximate event-driven AUC/Brier computed per bin.
+    diag_rows: List[Dict[str, object]] = []
+    diag_per_cause_parts: List[pd.DataFrame] = []
+
+    bins_iter = range(len(age_bins_years_arr) - 1)
+    if show_progress and tqdm is not None:
+        bins_iter = tqdm(bins_iter, total=len(
+            age_bins_years_arr) - 1, desc="Age bins")
+
+    for b in bins_iter:
+        lo = float(age_bins_years_arr[b])
+        hi = float(age_bins_years_arr[b + 1])
+        age_label = _format_age_bin_label(lo, hi)
+        m_rec = bin_idx == b
+        n_bin = int(m_rec.sum())
+        if n_bin == 0:
+            continue
+        rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32)
+
+        # Filter events to this bin's records once.
+        m_evt_bin = m_rec[evt_rec_idx] if evt_rec_idx.size > 0 else np.zeros(
+            (0,), dtype=bool)
+        evt_rec_idx_b = evt_rec_idx[m_evt_bin]
+        evt_cause_b = evt_cause[m_evt_bin]
+        evt_dt_b = evt_dt[m_evt_bin]
+
+        horizon_iter = enumerate(horizons)
+        if show_progress and tqdm is not None:
+            horizon_iter = tqdm(horizon_iter, total=len(
+                horizons), desc=f"Horizons {age_label}")
+
+        # Precompute a local index mapping for diagnostics label building.
+        local_map = np.full((N,), -1, dtype=np.int32)
+        local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32)
+
+        for h_idx, tau in horizon_iter:
+            s_tau_all = scores[:, :, h_idx]
+            s_tau = s_tau_all[m_rec]
+
+            # -------------------------
+            # Primary metric: Top-K Event Capture@tau (event-count based)
+            # -------------------------
+            denom_events = int(np.sum(evt_dt_b <= float(tau))
+                               ) if evt_dt_b.size > 0 else 0
+            if denom_events == 0:
+                for topk in args.topk_list:
+                    capture_rows.append(
+                        {
+                            "age_bin": age_label,
+                            "tau_years": float(tau),
+                            "topk": int(topk),
+                            "capture_at_k": float("nan"),
+                            "denom_events": int(0),
+                            "numer_events": int(0),
+                            "n_records": int(n_bin),
+                            "n_causes": int(K),
+                        }
+                    )
+            else:
+                m_evt_tau = evt_dt_b <= float(tau)
+                evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau]
+                evt_cause_tau = evt_cause_b[m_evt_tau]
+                # For each K, compute whether each event's cause is in that record's Top-K list.
+                for topk in args.topk_list:
+                    topk = int(topk)
+                    idx = topk_indices(s_tau_all, topk)  # shape (N, topk)
+                    idx_for_events = idx[evt_rec_idx_tau]
+                    hits = (idx_for_events ==
+                            evt_cause_tau[:, None]).any(axis=1)
+                    numer_events = int(hits.sum())
+                    capture = float(numer_events / denom_events)
+                    capture_rows.append(
+                        {
+                            "age_bin": age_label,
+                            "tau_years": float(tau),
+                            "topk": int(topk),
+                            "capture_at_k": capture,
+                            "denom_events": int(denom_events),
+                            "numer_events": int(numer_events),
+                            "n_records": int(n_bin),
+                            "n_causes": int(K),
+                        }
+                    )
+
+            # -------------------------
+            # Primary metric: WorkloadYield (Top-p% people)
+            # -------------------------
+            # Person-level score: max_k CIF_k(tau). This is used only for workloadyield ranking.
+            person_score = s_tau.max(axis=1) if K > 0 else np.zeros(
+                (n_bin,), dtype=np.float64)
+            order = np.argsort(-person_score, kind="mergesort")
+            counts_per_record = _event_counts_within_tau(
+                n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau)
+            total_events = int(counts_per_record.sum())
+            overall_events_per_person = (
+                total_events / float(n_bin)) if n_bin > 0 else float("nan")
+            for frac in args.workload_fracs:
+                frac = float(frac)
+                if frac <= 0.0:
+                    continue
+                n_sel = int(math.ceil(frac * n_bin))
+                n_sel = min(max(n_sel, 1), n_bin)
+                sel_local = order[:n_sel]
+                events_captured = int(counts_per_record[sel_local].sum())
+                capture_rate = float(
+                    events_captured / total_events) if total_events > 0 else float("nan")
+                selected_events_per_person = (
+                    events_captured / float(n_sel)) if n_sel > 0 else float("nan")
+                lift = (selected_events_per_person /
+                        overall_events_per_person) if overall_events_per_person > 0 else float("nan")
+                workload_rows.append(
+                    {
+                        "age_bin": age_label,
+                        "tau_years": float(tau),
+                        "frac_selected": float(frac),
+                        "n_selected": int(n_sel),
+                        "n_records": int(n_bin),
+                        "total_events": int(total_events),
+                        "events_captured": int(events_captured),
+                        "capture_rate": capture_rate,
+                        "lift_events_per_person": float(lift),
+                        "person_score_def": "max_k_CIF_k(tau)",
+                    }
+                )
+
+            # -------------------------
+            # Diagnostics (optional): approximate event-driven AUC/Brier
+            # -------------------------
+            # Convert event-level data to binary labels y[i,k]=1 iff >=1 event of cause k within tau.
+            y_tau_bin = np.zeros((n_bin, K), dtype=np.int8)
+            if evt_dt_b.size > 0:
+                m_evt_tau = evt_dt_b <= float(tau)
+                if np.any(m_evt_tau):
+                    rec_local = local_map[evt_rec_idx_b[m_evt_tau]]
+                    valid = rec_local >= 0
+                    y_tau_bin[rec_local[valid],
+                              evt_cause_b[m_evt_tau][valid]] = 1
+
+            n_pos = y_tau_bin.sum(axis=0).astype(np.int64)
+            n_neg = (int(n_bin) - n_pos).astype(np.int64)
+            brier_per_cause = np.mean(
+                (y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0
+            )
+            brier_macro = float(np.mean(brier_per_cause)
+                                ) if K > 0 else float("nan")
+            brier_weighted = float(np.sum(
+                brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
+
+            auc = np.full((K,), np.nan, dtype=np.float64)
+            min_pos = int(args.min_pos)
+            candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
+            for k in candidates:
+                auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype(
+                    np.int32), s_tau[:, k].astype(np.float64))
+
+            finite_auc = auc[np.isfinite(auc)]
+            auc_macro = float(np.mean(finite_auc)
+                              ) if finite_auc.size > 0 else float("nan")
+            w_mask = np.isfinite(auc)
+            auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
+                n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
+            n_valid_auc = int(np.isfinite(auc).sum())
+
+            diag_rows.append(
+                {
+                    "age_bin": age_label,
+                    "tau_years": float(tau),
+                    "n_records": int(n_bin),
+                    "n_causes": int(K),
+                    "auc_macro": auc_macro,
+                    "auc_weighted_by_npos": auc_weighted,
+                    "n_causes_valid_auc": int(n_valid_auc),
+                    "brier_macro": brier_macro,
+                    "brier_weighted_by_npos": brier_weighted,
+                }
+            )
+            diag_per_cause_parts.append(
+                pd.DataFrame(
+                    {
+                        "age_bin": age_label,
+                        "tau_years": float(tau),
+                        "cause_id": np.arange(K, dtype=np.int64),
+                        "n_pos": n_pos,
+                        "n_neg": n_neg,
+                        "auc": auc,
+                        "brier": brier_per_cause,
+                    }
+                )
+            )
+
+    out_capture = os.path.join(run_dir, "horizon_capture.csv")
+    out_wy = os.path.join(run_dir, "workload_yield.csv")
+    out_diag = os.path.join(run_dir, "horizon_metrics.csv")
+    out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv")
+
+    pd.DataFrame(capture_rows).to_csv(out_capture, index=False)
+    pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
+    pd.DataFrame(diag_rows).to_csv(out_diag, index=False)
+    if diag_per_cause_parts:
+        pd.concat(diag_per_cause_parts, ignore_index=True).to_csv(
+            out_diag_pc, index=False)
+    else:
+        pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos",
+                     "n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False)
+
+    print(f"Wrote {out_capture}")
+    print(f"Wrote {out_wy}")
+    print(f"Wrote {out_diag} (diagnostic-only)")
+    print(f"Wrote {out_diag_pc} (diagnostic-only)")
 
 
 if __name__ == "__main__":

View File

@@ -41,8 +41,8 @@ def parse_args() -> argparse.Namespace:
"--age_bins", "--age_bins",
type=str, type=str,
nargs="+", nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"], default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)", help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
) )
p.add_argument( p.add_argument(
@@ -80,9 +80,14 @@ def _compute_next_event_metrics(
     tau_short: float,
     min_pos: int,
 ) -> tuple[list[dict], pd.DataFrame]:
-    """Compute next-event metrics on a given subset.
+    """Compute next-event *primary* metrics on a given subset.
 
-    Definitions are unchanged from the original script.
+    Implements 评估方案.md (Next-event):
+    - score_k = CIF_k(tau_short)
+    - Hit@K / MRR are computed on records with an observed next-event.
+
+    Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table
+    based on whether the cause occurs within (t0, t0+tau_short] (display-only).
     """
     n_records_total = int(y_next.size)
     eligible = y_next >= 0
@@ -99,49 +104,40 @@ def _compute_next_event_metrics(
{"metric": "tau_short_years", "value": float(tau_short)}) {"metric": "tau_short_years", "value": float(tau_short)})
K = int(scores.shape[1]) K = int(scores.shape[1])
# Diagnostic: build per-cause AUC using within-window labels.
# This is NOT a primary metric (no IPCW / censoring adjustment).
diag_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.zeros((K,), dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
if n_records_total == 0: if n_records_total == 0:
per_cause_df = pd.DataFrame( metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.zeros((K,), dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")}) metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]: for k in [1, 3, 5, 10, 20]:
metrics_rows.append( metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")}) {"metric": f"hitrate_at_{k}", "value": float("nan")})
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")}) return metrics_rows, diag_df
return metrics_rows, per_cause_df
# If no eligible, keep coverage but leave accuracy-like metrics as NaN. # If no eligible, keep coverage but leave accuracy-like metrics as NaN.
if n_eligible == 0: if n_eligible == 0:
per_cause_df = pd.DataFrame( metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.full((K,), n_records_total, dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")}) metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]: for k in [1, 3, 5, 10, 20]:
metrics_rows.append( metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")}) {"metric": f"hitrate_at_{k}", "value": float("nan")})
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")}) return metrics_rows, diag_df
return metrics_rows, per_cause_df
scores_e = scores[eligible] scores_e = scores[eligible]
y_e = y_next[eligible] y_e = y_next[eligible]
pred = scores_e.argmax(axis=1) pred = scores_e.argmax(axis=1)
acc = float((pred == y_e).mean()) acc = float((pred == y_e).mean())
metrics_rows.append({"metric": "top1_accuracy", "value": acc}) metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
# MRR # MRR
order = np.argsort(-scores_e, axis=1, kind="mergesort") order = np.argsort(-scores_e, axis=1, kind="mergesort")
@@ -158,17 +154,53 @@ def _compute_next_event_metrics(
metrics_rows.append({"metric": f"hitrate_at_{k}", metrics_rows.append({"metric": f"hitrate_at_{k}",
"value": float(hit.mean())}) "value": float(hit.mean())})
# Macro OvR AUC per cause (optional) # Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here.
n_pos = np.bincount(y_e, minlength=K).astype(np.int64) _ = min_pos
n_neg = (int(y_e.size) - n_pos).astype(np.int64) return metrics_rows, diag_df
def _compute_within_window_auc(
*,
scores: np.ndarray,
records: list,
tau_short: float,
min_pos: int,
) -> pd.DataFrame:
"""Diagnostic-only per-cause AUC.
Label definition (event-driven, approximate; no IPCW):
y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
"""
n_records = int(len(records))
if n_records == 0:
return pd.DataFrame(
columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
)
K = int(scores.shape[1])
y = np.zeros((n_records, K), dtype=np.int8)
tau = float(tau_short)
# Build labels from future events.
for i, r in enumerate(records):
if r.future_causes.size == 0:
continue
m = r.future_dt_years <= tau
if not np.any(m):
continue
y[i, r.future_causes[m]] = 1
n_pos = y.sum(axis=0).astype(np.int64)
n_neg = (int(n_records) - n_pos).astype(np.int64)
auc = np.full((K,), np.nan, dtype=np.float64) auc = np.full((K,), np.nan, dtype=np.float64)
candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0)) candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
for k in candidates: for k in candidates:
auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k]) auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
scores[:, k].astype(np.float64))
included = (n_pos >= int(min_pos)) & (n_neg > 0) included = (n_pos >= int(min_pos)) & (n_neg > 0)
per_cause_df = pd.DataFrame( return pd.DataFrame(
{ {
"cause_id": np.arange(K, dtype=np.int64), "cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos, "n_pos": n_pos,
@@ -178,15 +210,6 @@ def _compute_next_event_metrics(
         }
     )
 
-    aucs = auc[np.isfinite(auc)]
-    if aucs.size > 0:
-        metrics_rows.append(
-            {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
-    else:
-        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
-
-    return metrics_rows, per_cause_df
 
 
 def main() -> None:
     args = parse_args()
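
The label definition in _compute_within_window_auc, traced on toy data. The rank comparison at the end is the standard Mann-Whitney form, shown only to make the AUC concrete; the project's actual implementation is utils.roc_auc_ovr:

import numpy as np

tau, K = 1.0, 3
future_dt = [np.array([0.4, 2.5]), np.array([5.0]), np.array([0.9])]
future_causes = [np.array([2, 0]), np.array([2]), np.array([2])]
y = np.zeros((3, K), dtype=np.int8)
for i, (dt, c) in enumerate(zip(future_dt, future_causes)):
    m = dt <= tau
    y[i, c[m]] = 1          # y[i,k]=1 iff cause k occurs in (t0, t0+tau]

s = np.array([0.9, 0.2, 0.8])               # scores for cause k=2
pos, neg = s[y[:, 2] == 1], s[y[:, 2] == 0]
auc = (pos[:, None] > neg[None, :]).mean()  # Mann-Whitney, ties ignored for brevity
print(y[:, 2], auc)  # [1 0 1] 1.0
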
@@ -246,19 +269,7 @@ def main() -> None:
     )
 
-    # Overall (preserve existing output files/shape)
-    metrics_rows, per_cause_df = _compute_next_event_metrics(
-        scores=scores,
-        y_next=y_next,
-        tau_short=tau,
-        min_pos=int(args.min_pos),
-    )
-    out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
-    pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
-    out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
-    per_cause_df.to_csv(out_pc, index=False)
-
-    # By age bin (new outputs)
+    # Strict protocol: evaluate independently per age bin (no mixing).
     age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
     age_bins_days = age_bins_years * DAYS_PER_YEAR
     # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
@@ -266,7 +277,7 @@ def main() -> None:
     bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1
 
     per_bin_metric_rows: List[dict] = []
-    per_bin_cause_parts: List[pd.DataFrame] = []
+    per_bin_auc_parts: List[pd.DataFrame] = []
     for b in range(len(age_bins_years) - 1):
         lo = float(age_bins_years[b])
         hi = float(age_bins_years[b + 1])
@@ -274,6 +285,7 @@ def main() -> None:
         m = bin_idx == b
         m_scores = scores[m]
         m_y = y_next[m]
+        m_records = [r for r, keep in zip(records, m) if bool(keep)]
         m_rows, m_pc = _compute_next_event_metrics(
             scores=m_scores,
             y_next=m_y,
@@ -282,25 +294,32 @@ def main() -> None:
         )
         for row in m_rows:
             per_bin_metric_rows.append({"age_bin": label, **row})
-        m_pc = m_pc.copy()
-        m_pc.insert(0, "age_bin", label)
-        per_bin_cause_parts.append(m_pc)
+        m_auc = _compute_within_window_auc(
+            scores=m_scores,
+            records=m_records,
+            tau_short=tau,
+            min_pos=int(args.min_pos),
+        )
+        m_auc.insert(0, "age_bin", label)
+        m_auc.insert(1, "tau_short_years", float(tau))
+        per_bin_auc_parts.append(m_auc)
 
     out_metrics_bins = os.path.join(
         run_dir, "next_event_metrics_by_age_bin.csv")
     pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
-    out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
-    if per_bin_cause_parts:
-        pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
-            out_pc_bins, index=False)
-    else:
-        pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
-                     "auc", "included"]).to_csv(out_pc_bins, index=False)
+    out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
+    if per_bin_auc_parts:
+        pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
+            out_auc_bins, index=False)
+    else:
+        pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
+                     "n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
 
-    print(f"Wrote {out_metrics}")
-    print(f"Wrote {out_pc}")
+    print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
+    print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
     print(f"Wrote {out_metrics_bins}")
-    print(f"Wrote {out_pc_bins}")
+    print(f"Wrote {out_auc_bins}")
 
 
 if __name__ == "__main__":