Enhance next-event evaluation with age-bin metrics and diagnostic AUC outputs

2026-01-17 15:31:12 +08:00
parent 197842b1a6
commit fcd948818c
2 changed files with 348 additions and 183 deletions

View File

@@ -1,18 +1,25 @@
"""Horizon-capture evaluation.
"""Horizon-capture evaluation (event-driven, age-stratified).
DISCLAIMERS (important):
- The reported AUC is "time-dependent" only because the label depends on the chosen horizon $\tau$.
Without explicit censoring / follow-up end times, this is NOT a classical risk-set AUC with IPCW.
Use it for model comparison and diagnostics, not for strict statistical interpretation.
This script implements the protocol described in 评估方案.md (the evaluation plan):
- The reported Brier scores are unadjusted diagnostic/proxy metrics (no censoring adjustment).
Use them to detect probability-mass compression / numerical stability issues; do not claim
calibrated absolute risk.
- Age-stratified evaluation: metrics are computed independently within each age bin (no mixing).
- Event-driven inclusion: each (person, age_bin) yields a record iff DOA <= bin upper bound and
there is at least one disease event in the bin; baseline t0 is sampled randomly from in-bin
disease events with t0 >= DOA.
- No follow-up completeness filtering (no t0+tau <= t_end constraint).
Primary outputs per age bin:
- Top-K Event Capture@tau (event-count based)
- WorkloadYield curves (Top-p% people by a person-level horizon score)
Secondary (diagnostic-only) outputs per age bin:
- Approximate event-driven AUC / Brier (no IPCW, no censoring adjustment)
"""
import argparse
import math
import os
from typing import Dict, List, Sequence
from typing import Dict, List, Sequence, Tuple
import numpy as np
import pandas as pd
@@ -40,6 +47,7 @@ from utils import (
roc_auc_ovr,
seed_everything,
topk_indices,
DAYS_PER_YEAR,
)
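roc_auc_ovr lives in the project's utils module and is not part of this diff. For reference, a self-contained rank-based (Mann-Whitney) AUC with the same (labels, scores) call shape — an assumed stand-in, not the project's actual implementation:

import numpy as np

def roc_auc_ovr_sketch(y: np.ndarray, s: np.ndarray) -> float:
    """Binary one-vs-rest AUC via the Mann-Whitney rank-sum identity."""
    n_pos = int(np.sum(y == 1))
    n_neg = int(y.size - n_pos)
    if n_pos == 0 or n_neg == 0:
        return float("nan")  # AUC undefined without both classes
    order = np.argsort(s, kind="mergesort")
    ranks = np.empty(y.size, dtype=np.float64)
    ranks[order] = np.arange(1, y.size + 1)
    for v in np.unique(s):               # average ranks across tied scores
        m = s == v
        if m.sum() > 1:
            ranks[m] = ranks[m].mean()
    u = ranks[y == 1].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u / (n_pos * n_neg))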
@@ -58,8 +66,8 @@ def parse_args() -> argparse.Namespace:
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
)
p.add_argument(
@@ -75,7 +83,14 @@ def parse_args() -> argparse.Namespace:
"--topk_list",
type=int,
nargs="+",
default=[1, 5, 10, 20, 50],
default=[5, 10, 20, 50],
)
p.add_argument(
"--workload_fracs",
type=float,
nargs="+",
default=[0.01, 0.02, 0.05, 0.1, 0.2, 0.5],
help="Fractions for workloadyield curves (Top-p%% people).",
)
p.add_argument(
"--no_tqdm",
@@ -85,6 +100,33 @@ def parse_args() -> argparse.Namespace:
return p.parse_args()
def _format_age_bin_label(lo: float, hi: float) -> str:
if np.isinf(hi):
return f"[{lo}, inf)"
return f"[{lo}, {hi})"
def _assign_age_bin_idx(t0_days: np.ndarray, age_bins_years: Sequence[float]) -> np.ndarray:
age_bins_years = np.asarray(list(age_bins_years), dtype=np.float64)
age_bins_days = age_bins_years * DAYS_PER_YEAR
# side="right" keeps a boundary-aged t0 in the upper bin, matching the documented [lo, hi) bins.
return np.searchsorted(age_bins_days, t0_days, side="right") - 1
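A quick check of the bin-assignment convention (left-closed bins; index -1 flags ages below the first boundary and is skipped later by the bin_idx == b filter). DAYS_PER_YEAR = 365.25 is assumed here:

import numpy as np

DAYS_PER_YEAR = 365.25  # assumed value of the utils constant
bins_days = np.array([40, 45, 50, np.inf]) * DAYS_PER_YEAR
t0_days = np.array([39.0, 40.0, 44.9, 45.0, 80.0]) * DAYS_PER_YEAR
print(np.searchsorted(bins_days, t0_days, side="right") - 1)
# [-1  0  0  1  2]: exactly-40 lands in [40, 45); 39 falls below the first boundary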
def _event_counts_within_tau(
n_records: int,
event_record_idx: np.ndarray,
event_dt_years: np.ndarray,
tau_years: float,
) -> np.ndarray:
"""Count events within (t0, t0+tau] per record (event-count, not unique causes)."""
if event_record_idx.size == 0:
return np.zeros((n_records,), dtype=np.int64)
m = event_dt_years <= float(tau_years)
if not np.any(m):
return np.zeros((n_records,), dtype=np.int64)
return np.bincount(event_record_idx[m], minlength=n_records).astype(np.int64)
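The bincount trick in _event_counts_within_tau, in miniature:

import numpy as np

# three records; flat events as (record_idx, dt_years) pairs
rec_idx = np.array([0, 0, 2])
dt_years = np.array([0.5, 2.0, 1.0])
m = dt_years <= 1.0                          # tau = 1 year
print(np.bincount(rec_idx[m], minlength=3))  # [1 0 1]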
def build_labels_within_tau_flat(
n_records: int,
n_causes: int,
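The rest of build_labels_within_tau_flat falls outside this hunk. A sketch consistent with the inline diagnostic label construction later in this diff (y[i, k] = 1 iff at least one event of cause k falls within (t0, t0+tau]); treat it as an assumed reconstruction, not the file's exact body:

import numpy as np

def build_labels_within_tau_flat_sketch(
    n_records: int,
    n_causes: int,
    evt_rec_idx: np.ndarray,
    evt_cause: np.ndarray,
    evt_dt_years: np.ndarray,
    tau_years: float,
) -> np.ndarray:
    """Binary (n_records, n_causes) labels from flat event arrays."""
    y = np.zeros((n_records, n_causes), dtype=np.int8)
    if evt_rec_idx.size:
        m = evt_dt_years <= float(tau_years)
        y[evt_rec_idx[m], evt_cause[m]] = 1  # duplicate events collapse to a single 1
    return y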
@@ -148,8 +190,8 @@ def main() -> None:
)
# Print disclaimers every run (requested)
print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
print("PRIMARY METRICS: event-count Capture@K and WorkloadYield, computed independently per age bin.")
print("DIAGNOSTICS ONLY: AUC/Brier below are event-driven approximations (no IPCW / censoring adjustment).")
scores = predict_cifs(
model,
@@ -171,131 +213,235 @@ def main() -> None:
raise ValueError("Record count mismatch")
# Pre-flatten all future events once to avoid repeated per-record scans.
# NOTE: these are event-level arrays (not unique causes), suitable for event-count Capture@K.
evt_rec_idx, evt_cause, evt_dt = flatten_future_events(records, n_causes=K)
per_tau_rows: List[Dict[str, object]] = []
per_cause_rows: List[Dict[str, object]] = []
# Assign each record to an age bin (based on t0; by construction t0 is within the bin).
t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
bin_idx = _assign_age_bin_idx(t0_days, age_bins_years)
age_bins_years_arr = np.asarray(list(age_bins_years), dtype=np.float64)
capture_rows: List[Dict[str, object]] = []
workload_rows: List[Dict[str, object]] = []
horizon_iter = enumerate(horizons)
# Diagnostics (optional): approximate event-driven AUC/Brier computed per bin.
diag_rows: List[Dict[str, object]] = []
diag_per_cause_parts: List[pd.DataFrame] = []
bins_iter = range(len(age_bins_years_arr) - 1)
if show_progress and tqdm is not None:
horizon_iter = tqdm(horizon_iter, total=len(
horizons), desc="Metrics by horizon")
bins_iter = tqdm(bins_iter, total=len(
age_bins_years_arr) - 1, desc="Age bins")
for h_idx, tau in horizon_iter:
s_tau = scores[:, :, h_idx]
y_tau = build_labels_within_tau_flat(
N, K, evt_rec_idx, evt_cause, evt_dt, tau)
for b in bins_iter:
lo = float(age_bins_years_arr[b])
hi = float(age_bins_years_arr[b + 1])
age_label = _format_age_bin_label(lo, hi)
# Per-cause counts + Brier (vectorized)
n_pos = y_tau.sum(axis=0).astype(np.int64)
n_neg = (int(N) - n_pos).astype(np.int64)
m_rec = bin_idx == b
n_bin = int(m_rec.sum())
if n_bin == 0:
continue
# Brier per cause: mean_i (y - s)^2
brier_per_cause = np.mean(
(y_tau.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0)
brier_macro = float(np.mean(brier_per_cause)) if K > 0 else float("nan")
brier_weighted = float(np.sum(
brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
rec_idx_bin = np.flatnonzero(m_rec).astype(np.int32)
# AUC: compute only for causes with enough positives and at least 1 negative
auc = np.full((K,), np.nan, dtype=np.float64)
min_pos = int(args.min_pos)
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
for k in candidates:
auc[k] = roc_auc_ovr(y_tau[:, k].astype(
np.int32), s_tau[:, k].astype(np.float64))
# Filter events to this bin's records once.
m_evt_bin = m_rec[evt_rec_idx] if evt_rec_idx.size > 0 else np.zeros(
(0,), dtype=bool)
evt_rec_idx_b = evt_rec_idx[m_evt_bin]
evt_cause_b = evt_cause[m_evt_bin]
evt_dt_b = evt_dt[m_evt_bin]
finite_auc = auc[np.isfinite(auc)]
auc_macro = float(np.mean(finite_auc)
) if finite_auc.size > 0 else float("nan")
w_mask = np.isfinite(auc)
auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
n_valid_auc = int(np.isfinite(auc).sum())
horizon_iter = enumerate(horizons)
if show_progress and tqdm is not None:
horizon_iter = tqdm(horizon_iter, total=len(
horizons), desc=f"Horizons {age_label}")
# Append per-cause rows (vectorized via DataFrame to avoid Python loops)
per_cause_rows.append(
pd.DataFrame(
{
"tau_years": float(tau),
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
"n_neg": n_neg,
"auc": auc,
"brier": brier_per_cause,
}
# Precompute a local index mapping for diagnostic label building.
local_map = np.full((N,), -1, dtype=np.int32)
local_map[rec_idx_bin] = np.arange(n_bin, dtype=np.int32)
for h_idx, tau in horizon_iter:
s_tau_all = scores[:, :, h_idx]
s_tau = s_tau_all[m_rec]
# -------------------------
# Primary metric: Top-K Event Capture@tau (event-count based)
# -------------------------
denom_events = int(np.sum(evt_dt_b <= float(tau))
) if evt_dt_b.size > 0 else 0
if denom_events == 0:
for topk in args.topk_list:
capture_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"topk": int(topk),
"capture_at_k": float("nan"),
"denom_events": int(0),
"numer_events": int(0),
"n_records": int(n_bin),
"n_causes": int(K),
}
)
else:
m_evt_tau = evt_dt_b <= float(tau)
evt_rec_idx_tau = evt_rec_idx_b[m_evt_tau]
evt_cause_tau = evt_cause_b[m_evt_tau]
# For each requested top-K, test whether each event's cause lands in that record's Top-K list.
for topk in args.topk_list:
topk = int(topk)
idx = topk_indices(s_tau_all, topk) # shape (N, topk)
idx_for_events = idx[evt_rec_idx_tau]
hits = (idx_for_events ==
evt_cause_tau[:, None]).any(axis=1)
numer_events = int(hits.sum())
capture = float(numer_events / denom_events)
capture_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"topk": int(topk),
"capture_at_k": capture,
"denom_events": int(denom_events),
"numer_events": int(numer_events),
"n_records": int(n_bin),
"n_causes": int(K),
}
)
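The per-event membership test above, in miniature. topk_indices is assumed to return each row's K highest-scoring column indices; np.argpartition gives an order-free stand-in:

import numpy as np

s = np.array([[0.1, 0.9, 0.3],
              [0.7, 0.2, 0.6]])
idx = np.argpartition(-s, 1, axis=1)[:, :2]  # top-2 causes per record
evt_rec = np.array([0, 0, 1])                # events: (rec, cause) = (0,1), (0,0), (1,2)
evt_cause = np.array([1, 0, 2])
hits = (idx[evt_rec] == evt_cause[:, None]).any(axis=1)
print(hits)                                  # [ True False  True] -> capture@2 = 2/3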
# -------------------------
# Primary metric: WorkloadYield (Top-p% people)
# -------------------------
# Person-level score: max_k CIF_k(tau). This is used only for WorkloadYield ranking.
person_score = s_tau.max(axis=1) if K > 0 else np.zeros(
(n_bin,), dtype=np.float64)
order = np.argsort(-person_score, kind="mergesort")
counts_per_record = _event_counts_within_tau(
n_bin, local_map[evt_rec_idx_b], evt_dt_b, tau)
total_events = int(counts_per_record.sum())
overall_events_per_person = (
total_events / float(n_bin)) if n_bin > 0 else float("nan")
for frac in args.workload_fracs:
frac = float(frac)
if frac <= 0.0:
continue
n_sel = int(math.ceil(frac * n_bin))
n_sel = min(max(n_sel, 1), n_bin)
sel_local = order[:n_sel]
events_captured = int(counts_per_record[sel_local].sum())
capture_rate = float(
events_captured / total_events) if total_events > 0 else float("nan")
selected_events_per_person = (
events_captured / float(n_sel)) if n_sel > 0 else float("nan")
lift = (selected_events_per_person /
overall_events_per_person) if overall_events_per_person > 0 else float("nan")
workload_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"frac_selected": float(frac),
"n_selected": int(n_sel),
"n_records": int(n_bin),
"total_events": int(total_events),
"events_captured": int(events_captured),
"capture_rate": capture_rate,
"lift_events_per_person": float(lift),
"person_score_def": "max_k_CIF_k(tau)",
}
)
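The WorkloadYield arithmetic in miniature: rank people by score, keep the top fraction, then compare events-per-person inside the selection to the bin average:

import math
import numpy as np

person_score = np.array([0.9, 0.1, 0.5, 0.3])
events = np.array([3, 0, 1, 0])                                  # events within tau per person
order = np.argsort(-person_score, kind="mergesort")
n_sel = min(max(math.ceil(0.5 * events.size), 1), events.size)   # top 50% -> 2 people
sel = order[:n_sel]
capture_rate = events[sel].sum() / events.sum()                  # 4/4 = 1.0
lift = events[sel].mean() / events.mean()                        # 2.0 / 1.0 = 2.0x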
# -------------------------
# Diagnostics (optional): approximate event-driven AUC/Brier
# -------------------------
# Convert event-level data to binary labels y[i,k]=1 iff >=1 event of cause k within tau.
y_tau_bin = np.zeros((n_bin, K), dtype=np.int8)
if evt_dt_b.size > 0:
m_evt_tau = evt_dt_b <= float(tau)
if np.any(m_evt_tau):
rec_local = local_map[evt_rec_idx_b[m_evt_tau]]
valid = rec_local >= 0
y_tau_bin[rec_local[valid],
evt_cause_b[m_evt_tau][valid]] = 1
n_pos = y_tau_bin.sum(axis=0).astype(np.int64)
n_neg = (int(n_bin) - n_pos).astype(np.int64)
brier_per_cause = np.mean(
(y_tau_bin.astype(np.float64) - s_tau.astype(np.float64)) ** 2, axis=0
)
)
brier_macro = float(np.mean(brier_per_cause)
) if K > 0 else float("nan")
brier_weighted = float(np.sum(
brier_per_cause * n_pos) / np.sum(n_pos)) if np.sum(n_pos) > 0 else float("nan")
# Business metrics for each topK
denom_true_pairs = int(y_tau.sum())
for topk in args.topk_list:
topk = int(topk)
idx = topk_indices(s_tau, topk)
captured = np.take_along_axis(y_tau, idx, axis=1)
hits = captured.sum(axis=1).astype(np.float64)
true_cnt = y_tau.sum(axis=1).astype(np.float64)
auc = np.full((K,), np.nan, dtype=np.float64)
min_pos = int(args.min_pos)
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
for k in candidates:
auc[k] = roc_auc_ovr(y_tau_bin[:, k].astype(
np.int32), s_tau[:, k].astype(np.float64))
precision_like = hits / float(min(topk, K))
mean_precision = float(np.mean(precision_like)
) if N > 0 else float("nan")
finite_auc = auc[np.isfinite(auc)]
auc_macro = float(np.mean(finite_auc)
) if finite_auc.size > 0 else float("nan")
w_mask = np.isfinite(auc)
auc_weighted = float(np.sum(auc[w_mask] * n_pos[w_mask]) / np.sum(
n_pos[w_mask])) if np.sum(n_pos[w_mask]) > 0 else float("nan")
n_valid_auc = int(np.isfinite(auc).sum())
mask_has_true = true_cnt > 0
recall_like = np.full((N,), np.nan, dtype=np.float64)
recall_like[mask_has_true] = hits[mask_has_true] / \
true_cnt[mask_has_true]
mean_recall = float(np.nanmean(recall_like)) if np.any(
mask_has_true) else float("nan")
median_recall = float(np.nanmedian(recall_like)) if np.any(
mask_has_true) else float("nan")
numer_captured_pairs = int(captured.sum())
pop_capture_rate = float(
numer_captured_pairs / denom_true_pairs) if denom_true_pairs > 0 else float("nan")
workload_rows.append(
diag_rows.append(
{
"age_bin": age_label,
"tau_years": float(tau),
"topk": int(topk),
"population_capture_rate": pop_capture_rate,
"mean_precision_like": mean_precision,
"mean_recall_like": mean_recall,
"median_recall_like": median_recall,
"denom_true_pairs": denom_true_pairs,
"numer_captured_pairs": numer_captured_pairs,
"n_records": int(n_bin),
"n_causes": int(K),
"auc_macro": auc_macro,
"auc_weighted_by_npos": auc_weighted,
"n_causes_valid_auc": int(n_valid_auc),
"brier_macro": brier_macro,
"brier_weighted_by_npos": brier_weighted,
}
)
per_tau_rows.append(
{
"tau_years": float(tau),
"n_records": int(N),
"n_causes": int(K),
"auc_macro": auc_macro,
"auc_weighted_by_npos": auc_weighted,
"n_causes_valid_auc": int(n_valid_auc),
"brier_macro": brier_macro,
"brier_weighted_by_npos": brier_weighted,
"total_true_pairs": denom_true_pairs,
}
)
diag_per_cause_parts.append(
pd.DataFrame(
{
"age_bin": age_label,
"tau_years": float(tau),
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
"n_neg": n_neg,
"auc": auc,
"brier": brier_per_cause,
}
)
)
out_metrics = os.path.join(run_dir, "horizon_metrics.csv")
out_pc = os.path.join(run_dir, "horizon_per_cause.csv")
out_capture = os.path.join(run_dir, "horizon_capture.csv")
out_wy = os.path.join(run_dir, "workload_yield.csv")
out_diag = os.path.join(run_dir, "horizon_metrics.csv")
out_diag_pc = os.path.join(run_dir, "horizon_per_cause.csv")
pd.DataFrame(per_tau_rows).to_csv(out_metrics, index=False)
if per_cause_rows:
pd.concat(per_cause_rows, ignore_index=True).to_csv(out_pc, index=False)
else:
pd.DataFrame(columns=["tau_years", "cause_id", "n_pos",
"n_neg", "auc", "brier"]).to_csv(out_pc, index=False)
pd.DataFrame(capture_rows).to_csv(out_capture, index=False)
pd.DataFrame(workload_rows).to_csv(out_wy, index=False)
pd.DataFrame(diag_rows).to_csv(out_diag, index=False)
if diag_per_cause_parts:
pd.concat(diag_per_cause_parts, ignore_index=True).to_csv(
out_diag_pc, index=False)
else:
pd.DataFrame(columns=["age_bin", "tau_years", "cause_id", "n_pos",
"n_neg", "auc", "brier"]).to_csv(out_diag_pc, index=False)
print(f"Wrote {out_metrics}")
print(f"Wrote {out_pc}")
print(f"Wrote {out_capture}")
print(f"Wrote {out_wy}")
print(f"Wrote {out_diag} (diagnostic-only)")
print(f"Wrote {out_diag_pc} (diagnostic-only)")
if __name__ == "__main__":

View File

@@ -41,8 +41,8 @@ def parse_args() -> argparse.Namespace:
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
)
p.add_argument(
@@ -80,9 +80,14 @@ def _compute_next_event_metrics(
tau_short: float,
min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
"""Compute next-event metrics on a given subset.
"""Compute next-event *primary* metrics on a given subset.
Definitions are unchanged from the original script.
Implements 评估方案.md (Next-event):
- score_k = CIF_k(tau_short)
- Hit@K / MRR are computed on records with an observed next-event.
Returns (metrics_rows, diag_df). diag_df is a placeholder per-cause table here; the
diagnostic per-cause AUC (y=1 iff the cause occurs within (t0, t0+tau_short]) is
computed separately by _compute_within_window_auc (display-only).
"""
n_records_total = int(y_next.size)
eligible = y_next >= 0
@@ -99,49 +104,40 @@ def _compute_next_event_metrics(
{"metric": "tau_short_years", "value": float(tau_short)})
K = int(scores.shape[1])
# Diagnostic: build per-cause AUC using within-window labels.
# This is NOT a primary metric (no IPCW / censoring adjustment).
diag_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.zeros((K,), dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
if n_records_total == 0:
per_cause_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.zeros((K,), dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]:
metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")})
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
return metrics_rows, per_cause_df
return metrics_rows, diag_df
# If there are no eligible records, keep coverage but leave accuracy-like metrics as NaN.
if n_eligible == 0:
per_cause_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.full((K,), n_records_total, dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]:
metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")})
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
return metrics_rows, per_cause_df
return metrics_rows, diag_df
scores_e = scores[eligible]
y_e = y_next[eligible]
pred = scores_e.argmax(axis=1)
acc = float((pred == y_e).mean())
metrics_rows.append({"metric": "top1_accuracy", "value": acc})
metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
# MRR
order = np.argsort(-scores_e, axis=1, kind="mergesort")
@@ -158,17 +154,53 @@ def _compute_next_event_metrics(
metrics_rows.append({"metric": f"hitrate_at_{k}",
"value": float(hit.mean())})
# Macro OvR AUC per cause (optional)
n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
n_neg = (int(y_e.size) - n_pos).astype(np.int64)
# Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here.
_ = min_pos
return metrics_rows, diag_df
def _compute_within_window_auc(
*,
scores: np.ndarray,
records: list,
tau_short: float,
min_pos: int,
) -> pd.DataFrame:
"""Diagnostic-only per-cause AUC.
Label definition (event-driven, approximate; no IPCW):
y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
"""
n_records = int(len(records))
if n_records == 0:
return pd.DataFrame(
columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
)
K = int(scores.shape[1])
y = np.zeros((n_records, K), dtype=np.int8)
tau = float(tau_short)
# Build labels from future events.
for i, r in enumerate(records):
if r.future_causes.size == 0:
continue
m = r.future_dt_years <= tau
if not np.any(m):
continue
y[i, r.future_causes[m]] = 1
n_pos = y.sum(axis=0).astype(np.int64)
n_neg = (int(n_records) - n_pos).astype(np.int64)
auc = np.full((K,), np.nan, dtype=np.float64)
candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
for k in candidates:
auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
scores[:, k].astype(np.float64))
included = (n_pos >= int(min_pos)) & (n_neg > 0)
per_cause_df = pd.DataFrame(
return pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
@@ -178,15 +210,6 @@ def _compute_next_event_metrics(
}
)
aucs = auc[np.isfinite(auc)]
if aucs.size > 0:
metrics_rows.append(
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
else:
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
return metrics_rows, per_cause_df
def main() -> None:
args = parse_args()
@@ -246,19 +269,7 @@ def main() -> None:
)
# Overall (preserve existing output files/shape)
metrics_rows, per_cause_df = _compute_next_event_metrics(
scores=scores,
y_next=y_next,
tau_short=tau,
min_pos=int(args.min_pos),
)
out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
per_cause_df.to_csv(out_pc, index=False)
# By age bin (new outputs)
# Strict protocol: evaluate independently per age bin (no mixing).
age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
age_bins_days = age_bins_years * DAYS_PER_YEAR
# Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
@@ -266,7 +277,7 @@ def main() -> None:
# side="right" matches the [b_i, b_{i+1}) convention noted above.
bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1
per_bin_metric_rows: List[dict] = []
per_bin_cause_parts: List[pd.DataFrame] = []
per_bin_auc_parts: List[pd.DataFrame] = []
for b in range(len(age_bins_years) - 1):
lo = float(age_bins_years[b])
hi = float(age_bins_years[b + 1])
@@ -274,6 +285,7 @@ def main() -> None:
m = bin_idx == b
m_scores = scores[m]
m_y = y_next[m]
m_records = [r for r, keep in zip(records, m) if bool(keep)]
m_rows, m_pc = _compute_next_event_metrics(
scores=m_scores,
y_next=m_y,
@@ -282,25 +294,32 @@ def main() -> None:
)
for row in m_rows:
per_bin_metric_rows.append({"age_bin": label, **row})
m_pc = m_pc.copy()
m_pc.insert(0, "age_bin", label)
per_bin_cause_parts.append(m_pc)
m_auc = _compute_within_window_auc(
scores=m_scores,
records=m_records,
tau_short=tau,
min_pos=int(args.min_pos),
)
m_auc.insert(0, "age_bin", label)
m_auc.insert(1, "tau_short_years", float(tau))
per_bin_auc_parts.append(m_auc)
out_metrics_bins = os.path.join(
run_dir, "next_event_metrics_by_age_bin.csv")
pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
if per_bin_cause_parts:
pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
out_pc_bins, index=False)
else:
pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
"auc", "included"]).to_csv(out_pc_bins, index=False)
print(f"Wrote {out_metrics}")
print(f"Wrote {out_pc}")
out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
if per_bin_auc_parts:
pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
out_auc_bins, index=False)
else:
pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
"n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
print(f"Wrote {out_metrics_bins}")
print(f"Wrote {out_pc_bins}")
print(f"Wrote {out_auc_bins}")
if __name__ == "__main__":