Refactor AUC computation methods and introduce Event Rate@K for cross-cause prioritization

This commit is contained in:
2026-01-11 00:47:56 +08:00
parent d87752d1f8
commit 4d53f52aa1

View File

@@ -4,7 +4,6 @@ import json
import math
import os
import random
import statistics
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -522,7 +521,7 @@ def check_cif_integrity(
# Metrics
# ============================================================
# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
# --- Rank-based ROC AUC (ties handled via midranks) ---
def compute_midrank(x: np.ndarray) -> np.ndarray:
"""Vectorized midrank computation (ties -> average ranks)."""
@@ -554,75 +553,6 @@ def compute_midrank(x: np.ndarray) -> np.ndarray:
return out
def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for computing AUCs and their covariance.

    Args:
        predictions_sorted_transposed: scores of shape
            (n_classifiers, n_examples) with the positive examples occupying
            the first ``label_1_count`` columns. A 1-D vector is accepted and
            treated as a single classifier.
        label_1_count: number of positive examples (m); the remaining
            columns are treated as negatives (n).

    Returns:
        ``(aucs, delong_cov)`` where ``aucs`` has shape (n_classifiers,) and
        ``delong_cov`` has shape (n_classifiers, n_classifiers). When either
        class is empty both outputs are filled with NaN (same shapes).
    """
    # atleast_2d lets callers pass a bare score vector for one classifier.
    preds = np.atleast_2d(np.asarray(predictions_sorted_transposed, dtype=float))
    n_classifiers = int(preds.shape[0])
    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        # Degenerate split: keep output shapes consistent with the number of
        # classifiers so callers indexing [i] / [i, j] do not go out of range.
        return (
            np.full(n_classifiers, float("nan")),
            np.full((n_classifiers, n_classifiers), float("nan")),
        )

    pos = preds[:, :m]
    neg = preds[:, m:]

    # Midranks within positives, within negatives, and over the pooled sample.
    tx = np.array([compute_midrank(x) for x in pos])
    ty = np.array([compute_midrank(x) for x in neg])
    tz = np.array([compute_midrank(x) for x in preds])

    # Rank-sum form of the AUC, one value per classifier.
    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)

    # Per-positive and per-negative components whose (co)variances make up
    # the DeLong covariance below.
    v01 = (tz[:, :m] - tx) / n
    v10 = 1.0 - (tz[:, m:] - ty) / m

    if v01.shape[0] > 1:
        sx = np.cov(v01)
        sy = np.cov(v10)
    else:
        # Single-classifier case: compute row-wise variance (do not flatten)
        # and wrap it as a 1x1 matrix. With m == 1 or n == 1 this is NaN.
        var_v01 = float(np.var(v01, axis=1, ddof=1)[0])
        var_v10 = float(np.var(v10, axis=1, ddof=1)[0])
        sx = np.array([[var_v01]])
        sy = np.array([[var_v10]])

    delong_cov = sx / m + sy / n
    return aucs, delong_cov
def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    """Return ``(auc, variance)`` for one score vector via ``fastDeLong``.

    Args:
        ground_truth: 1-D binary labels (0 = negative, 1 = positive).
        predictions: 1-D scores, same length as ``ground_truth``.

    Returns:
        ``(auc, var)`` as floats; ``(nan, nan)`` when either class is absent.

    Raises:
        ValueError: if the inputs are not 1-D arrays of equal length, or if
            both classes are present but some labels are neither 0 nor 1.
    """
    y = np.asarray(ground_truth, dtype=int)
    p = np.asarray(predictions, dtype=float)
    if y.ndim != 1 or p.ndim != 1 or y.shape[0] != p.shape[0]:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")

    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan"), float("nan")
    if m + n != y.shape[0]:
        # Any other label value would be sorted into the "positives first"
        # prefix and inflate the negative count derived downstream, silently
        # corrupting both the AUC and its variance.
        raise ValueError("calc_auc_variance expects binary labels in {0, 1}")

    # Positives first; a stable sort keeps the within-class order reproducible.
    order = np.argsort(-y, kind="stable")
    preds_sorted = p[order]
    aucs, cov = fastDeLong(preds_sorted[np.newaxis, :], m)
    return float(aucs[0]), float(cov[0, 0])
def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Compute the AUC with a normal-approximation interval from the DeLong variance.

    ``alpha`` is the two-sided coverage level (0.95 -> 95% interval). The
    bounds are clipped to [0, 1]. If the variance is NaN, infinite or not
    strictly positive, a warning is printed and both bounds are NaN.

    Returns:
        ``(auc, ci_low, ci_high)`` as floats.
    """
    auc, variance = calc_auc_variance(ground_truth, predictions)

    variance_usable = bool(np.isfinite(variance)) and variance > 0
    if not variance_usable:
        print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
        return float(auc), float("nan"), float("nan")

    # Two-sided critical value of the standard normal for the requested level.
    upper_tail_prob = 1.0 - (1.0 - float(alpha)) / 2.0
    critical = statistics.NormalDist().inv_cdf(upper_tail_prob)
    std_err = math.sqrt(variance)

    ci_low = max(0.0, auc - critical * std_err)
    ci_high = min(1.0, auc + critical * std_err)
    return float(auc), float(ci_low), float(ci_high)
def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
"""Rank-based ROC AUC via MannWhitney U statistic (ties handled by midranks).
@@ -643,49 +573,6 @@ def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
return float(auc)
def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Bootstrap CI for ROC AUC (percentile).

    Args:
        scores: 1-D predicted scores.
        labels: 1-D integer labels, same length as ``scores``.
        n_bootstrap: number of resamples to draw (with replacement).
        alpha: two-sided coverage level of the interval.
        seed: seed for the NumPy random generator, for reproducibility.

    Returns:
        ``(auc_full, ci_low, ci_high)``. All three are NaN when the labels
        are empty/single-valued or the full-sample AUC is not finite. Only
        the bounds are NaN when fewer than 10 resamples give a finite AUC.

    Raises:
        ValueError: if the inputs are not 1-D arrays of equal length.
    """
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    if scores.ndim != 1 or labels.ndim != 1 or scores.shape[0] != labels.shape[0]:
        # Without this check, a longer score vector would be silently
        # truncated by the resampling index below.
        raise ValueError("bootstrap_auc_ci expects 1D arrays of equal length")

    n = labels.shape[0]
    if n == 0 or np.all(labels == labels[0]):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    aucs: List[float] = []
    for _ in range(int(n_bootstrap)):
        idx = rng.integers(0, n, size=n)
        yb = labels[idx]
        if np.all(yb == yb[0]):
            # A single-class resample has no defined AUC; drop it.
            continue
        auc = roc_auc_rank(yb, scores[idx])
        if np.isfinite(auc):
            aucs.append(float(auc))

    if len(aucs) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), float("nan"), float("nan")

    # Percentile interval: symmetric tail mass on each side.
    lo_q = (1.0 - float(alpha)) / 2.0
    hi_q = 1.0 - lo_q
    lo = float(np.quantile(aucs, lo_q))
    hi = float(np.quantile(aucs, hi_q))
    return float(auc_full), lo, hi
def brier_score(p: np.ndarray, y: np.ndarray) -> float:
p = np.asarray(p, dtype=float)
y = np.asarray(y, dtype=float)
@@ -1040,23 +927,75 @@ def compute_capture_points(
return rows
def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]:
    """Assign each distinct horizon to a short / medium / long bucket.

    Horizons are de-duplicated and sorted ascending; the four smallest are
    "short" (rank 1), the next four "medium" (rank 2), and any remaining
    ones "long" (rank 3).

    Returns:
        ``(rows, mapping, method)``: one row dict per distinct horizon with
        keys ``horizon``, ``group`` and ``group_rank``; a horizon -> group
        name lookup; and a string naming the grouping rule.
    """
    bucket_names = ("short", "medium", "long")
    bucket_width = 4

    horizon_to_group: Dict[float, str] = {}
    group_rows: List[Dict[str, Any]] = []

    distinct_sorted = sorted(set(float(h) for h in horizons))
    for position, value in enumerate(distinct_sorted):
        # Positions 0-3 -> bucket 0, 4-7 -> bucket 1, everything later -> bucket 2.
        bucket = min(position // bucket_width, len(bucket_names) - 1)
        name = bucket_names[bucket]
        horizon_to_group[float(value)] = name
        group_rows.append(
            {"horizon": float(value), "group": name, "group_rank": int(bucket + 1)}
        )

    return group_rows, horizon_to_group, "continuous_unique_horizons_first4_next4_rest"
def compute_event_rate_at_topk_causes(
    p_tau: np.ndarray,
    y_tau: np.ndarray,
    topk_list: Sequence[int],
) -> List[Dict[str, Any]]:
    """Compute Event Rate@K for cross-cause prioritization.

    For each individual, rank causes by predicted risk ``p_tau`` at a fixed
    horizon. For each requested K, take that individual's top-K causes and
    compute the fraction of them that occur within the horizon; the
    per-individual fractions are then summarised by mean and median.

    Args:
        p_tau: (N, C) predicted CIFs at a fixed horizon (N individuals,
            C causes).
        y_tau: (N, C) binary labels (0/1): whether the cause occurs within
            the horizon.
        topk_list: K values to evaluate. Non-positive and duplicate values
            are dropped. Values above C are clipped to C, and requests that
            clip to the same value yield a single row.

    Returns:
        One row per distinct effective K, ascending, with keys ``topk``,
        ``event_rate_mean``, ``event_rate_median`` and ``n_total``. If N or
        C is zero, the rates are NaN and ``topk`` is the requested K. An
        empty list is returned when no positive K is requested.

    Raises:
        ValueError: if the inputs are not 2-D arrays of identical shape.
    """
    p = np.asarray(p_tau, dtype=float)
    y = np.asarray(y_tau, dtype=float)
    if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape:
        raise ValueError(
            "compute_event_rate_at_topk_causes expects (N,K) arrays of equal shape")

    n, k_total = p.shape

    # Sanitize the K list once so the empty-input and regular paths agree.
    topks = sorted({int(x) for x in topk_list if int(x) > 0})
    if not topks:
        return []

    if n == 0 or k_total == 0:
        return [
            {
                "topk": int(kk),
                "event_rate_mean": float("nan"),
                "event_rate_median": float("nan"),
                "n_total": int(n),
            }
            for kk in topks
        ]

    # Clip to the number of causes and de-duplicate AFTER clipping; otherwise
    # e.g. K=50 and K=100 with 30 causes would both be reported as topk=30.
    effective_ks = sorted({min(kk, int(k_total)) for kk in topks})
    max_k = effective_ks[-1]

    # Efficient: get top max_k causes per individual, then sort within those.
    part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k]  # (N, max_k)
    p_part = np.take_along_axis(p, part, axis=1)
    order = np.argsort(-p_part, axis=1)
    top_sorted = np.take_along_axis(part, order, axis=1)  # (N, max_k)

    out_rows: List[Dict[str, Any]] = []
    for kk_eff in effective_ks:
        idx = top_sorted[:, :kk_eff]
        y_sel = np.take_along_axis(y, idx, axis=1)
        # Fraction of the selected causes that occur, per individual.
        per_person = np.mean(y_sel, axis=1)
        out_rows.append(
            {
                "topk": int(kk_eff),
                "event_rate_mean": float(np.mean(per_person)),
                "event_rate_median": float(np.median(per_person)),
                "n_total": int(n),
            }
        )
    return out_rows
def count_occurs_within_horizon(
@@ -1293,8 +1232,6 @@ def evaluate_one_model(
out_rows: List[Dict[str, Any]],
calib_rows: List[Dict[str, Any]],
calib_cause_ids: Optional[Sequence[int]],
auc_ci_method: str,
bootstrap_n: int,
n_calib_bins: int = 10,
metric_workers: int = 0,
progress: str = "auto",
@@ -1375,15 +1312,8 @@ def evaluate_one_model(
}
)
# Secondary: discrimination via AUC at the same horizon.
if auc_ci_method == "none":
auc, lo, hi = float("nan"), float("nan"), float("nan")
elif auc_ci_method == "bootstrap":
auc, lo, hi = bootstrap_auc_ci(
p, y, n_bootstrap=bootstrap_n, alpha=0.95
)
else:
auc, lo, hi = delong_ci(y, p, alpha=0.95)
# Secondary: discrimination via AUC at the same horizon (point estimate only).
auc = roc_auc_rank(y, p)
local_rows.append(
{
@@ -1392,8 +1322,8 @@ def evaluate_one_model(
"horizon": float(tau),
"cause": int(cid),
"value": float(auc),
"ci_low": lo,
"ci_high": hi,
"ci_low": "",
"ci_high": "",
}
)
@@ -1592,15 +1522,6 @@ def main() -> int:
ap.add_argument("--integrity_strict", action="store_true", default=False)
ap.add_argument("--integrity_tol", type=float, default=1e-6)
# AUC CI methods
ap.add_argument(
"--auc_ci_method",
type=str,
default="delong",
choices=["delong", "bootstrap", "none"],
)
ap.add_argument("--bootstrap_n", type=int, default=2000)
# Speed/UX
ap.add_argument(
"--metric_workers",
@@ -1630,6 +1551,15 @@ def main() -> int:
default=50,
help="If >0, also export a dense capture curve for k=1..max_pct",
)
# High-risk cause concentration (cross-cause prioritization)
ap.add_argument(
"--cause_concentration_topk",
type=int,
nargs="*",
default=[5, 10, 20, 50],
help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
)
args = ap.parse_args()
set_deterministic(args.seed)
@@ -1708,25 +1638,13 @@ def main() -> int:
}
)
# Horizon groups for Experiment 3
hg_rows, horizon_to_group, hg_method = make_horizon_groups(
args.eval_horizons)
write_simple_csv(
os.path.join(export_dir, "horizon_groups.csv"),
["horizon", "group", "group_rank"],
hg_rows,
)
summary_rows: List[Dict[str, Any]] = []
calib_rows: List[Dict[str, Any]] = []
# Experiment exports (accumulated across models)
rs_bins_rows: List[Dict[str, Any]] = []
rs_sum_rows: List[Dict[str, Any]] = []
cap_points_rows: List[Dict[str, Any]] = []
cap_curve_rows: List[Dict[str, Any]] = []
cal_group_sum_rows: List[Dict[str, Any]] = []
cal_group_bins_rows: List[Dict[str, Any]] = []
conc_rows: List[Dict[str, Any]] = []
# Track per-model integrity status for meta JSON.
integrity_meta: Dict[str, Any] = {}
@@ -1817,8 +1735,6 @@ def main() -> int:
out_rows=model_rows,
calib_rows=calib_rows,
calib_cause_ids=top_cause_ids.tolist(),
auc_ci_method=str(args.auc_ci_method),
bootstrap_n=int(args.bootstrap_n),
metric_workers=int(args.metric_workers),
progress=str(args.progress),
)
@@ -1831,61 +1747,37 @@ def main() -> int:
)
summary_rows.extend(model_summary_rows)
# Convenience slices for user-facing experiments (focus causes only).
cause_cif_focus = cif_full[:, top_cause_ids, :]
y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
# ============================================================
# Experiment 1: Risk stratification bins + summary
# Experiment: High-Risk Cause Concentration at fixed horizon
# (cross-cause prioritization accuracy)
# ============================================================
topk_causes = [int(x) for x in args.cause_concentration_topk]
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for h_i, tau in enumerate(args.eval_horizons):
for j, cause_id in enumerate(top_cause_ids.tolist()):
p = cause_cif_focus[:, j, h_i]
y = y_within_focus[:, j, h_i]
if sex_mask is not None:
p = p[sex_mask]
y = y[sex_mask]
q_used, bin_rows, summary = compute_risk_stratification_bins(
p, y, q_default=10)
for br in bin_rows:
rs_bins_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"horizon": float(tau),
"sex": sex_label,
"q": int(br["q"]),
"n_bin": int(br["n_bin"]),
"p_mean": _safe_float(br["p_mean"]),
"y_rate": _safe_float(br["y_rate"]),
"y_overall": _safe_float(br["y_overall"]),
"lift_vs_overall": _safe_float(br["lift_vs_overall"]),
"q_total": int(q_used),
}
)
rs_sum_rows.append(
p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float)
y_tau_all = np.asarray(
y_cause_within_tau[:, :, h_i], dtype=float)
if sex_mask is not None:
p_tau_all = p_tau_all[sex_mask]
y_tau_all = y_tau_all[sex_mask]
for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes):
conc_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"horizon": float(tau),
"sex": sex_label,
"q_total": int(q_used),
"top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]),
"bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]),
"lift_top10_vs_bottom50": _safe_float(summary["lift_top10_vs_bottom50"]),
"slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]),
**rr,
}
)
# Convenience slices for user-facing experiments (focus causes only).
cause_cif_focus = cif_full[:, top_cause_ids, :]
y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
# ============================================================
# Experiment 2: High-risk capture points (+ optional curve)
# ============================================================
@@ -1932,129 +1824,6 @@ def main() -> int:
}
)
# ============================================================
# Experiment 3: Short/Medium/Long horizon-group calibration
# ============================================================
# Per-horizon metrics for grouping
# Build a dict for quick access: (cause_id, horizon) -> (brier, ici)
per_h: Dict[Tuple[int, float], Dict[str, float]] = {}
for rr in model_rows:
if rr.get("metric_name") not in {"cause_brier", "cause_ici"}:
continue
try:
cid = int(rr.get("cause"))
except Exception:
continue
if cid not in set(int(x) for x in top_cause_ids.tolist()):
continue
h = _safe_float(rr.get("horizon"))
if not np.isfinite(h):
continue
key = (cid, float(h))
d = per_h.get(key, {})
d[str(rr.get("metric_name"))] = _safe_float(rr.get("value"))
per_h[key] = d
# Compute group summaries and pooled bins using the same quantile bins as exp1 (per slice).
for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
for j, cause_id in enumerate(top_cause_ids.tolist()):
# Decide Q per slice for pooled reliability curve
n_slice = int(np.sum(sex_mask)) if sex_mask is not None else int(
sex.shape[0])
q_pool = 10 if n_slice >= 200 else 5
# Collect per-horizon brier/ici values
group_vals: Dict[str, Dict[str, List[float]]] = {"short": {"brier": [], "ici": [
]}, "medium": {"brier": [], "ici": []}, "long": {"brier": [], "ici": []}}
group_n_total: Dict[str, int] = {
"short": 0, "medium": 0, "long": 0}
# Pooled bins: group -> q -> accumulators
pooled: Dict[str, Dict[int, Dict[str, float]]] = {
"short": {}, "medium": {}, "long": {}}
for h_i, tau in enumerate(args.eval_horizons):
g = horizon_to_group.get(float(tau), "long")
# brier/ici per horizon (already computed at full-sample level)
d = per_h.get((int(cause_id), float(tau)), {})
brier_h = _safe_float(d.get("cause_brier"))
ici_h = _safe_float(d.get("cause_ici"))
if np.isfinite(brier_h):
group_vals[g]["brier"].append(brier_h)
if np.isfinite(ici_h):
group_vals[g]["ici"].append(ici_h)
# pooled reliability bins from raw p/y
p = cause_cif_focus[:, j, h_i]
y = y_within_focus[:, j, h_i]
if sex_mask is not None:
p = p[sex_mask]
y = y[sex_mask]
if p.size == 0:
continue
edges = _quantile_edges(p, q_pool)
for qi in range(q_pool):
m = (p > edges[qi]) & (p <= edges[qi + 1])
nb = int(np.sum(m))
if nb == 0:
continue
pm = float(np.mean(p[m]))
yr = float(np.mean(y[m]))
acc = pooled[g].get(
qi + 1, {"n": 0.0, "p_sum": 0.0, "y_sum": 0.0})
acc["n"] += float(nb)
acc["p_sum"] += float(nb) * pm
acc["y_sum"] += float(nb) * yr
pooled[g][qi + 1] = acc
group_n_total[g] = max(group_n_total[g], int(p.size))
for g in ["short", "medium", "long"]:
bvals = group_vals[g]["brier"]
ivals = group_vals[g]["ici"]
cal_group_sum_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"sex": sex_label,
"horizon_group": g,
"brier_mean": float(np.mean(bvals)) if bvals else float("nan"),
"brier_median": float(np.median(bvals)) if bvals else float("nan"),
"ici_mean": float(np.mean(ivals)) if ivals else float("nan"),
"ici_median": float(np.median(ivals)) if ivals else float("nan"),
"n_total": int(group_n_total[g]),
"horizon_grouping_method": hg_method,
}
)
for qi in range(1, q_pool + 1):
acc = pooled[g].get(qi)
if not acc or float(acc.get("n", 0.0)) <= 0:
continue
n_bin = float(acc["n"])
cal_group_bins_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"cause": int(cause_id),
"sex": sex_label,
"horizon_group": g,
"q": int(qi),
"n_bin": int(n_bin),
"p_mean": float(acc["p_sum"] / n_bin),
"y_rate": float(acc["y_sum"] / n_bin),
"q_total": int(q_pool),
"horizon_grouping_method": hg_method,
}
)
# Optionally write top-cause counts into the main results CSV as metric rows.
for tc in top_causes_meta:
model_rows.append(
@@ -2141,46 +1910,6 @@ def main() -> int:
write_calibration_bins_csv(calib_csv_path, calib_rows)
# Write experiment exports
write_simple_csv(
os.path.join(export_dir, "risk_stratification_bins.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"horizon",
"sex",
"q",
"n_bin",
"p_mean",
"y_rate",
"y_overall",
"lift_vs_overall",
"q_total",
],
rs_bins_rows,
)
write_simple_csv(
os.path.join(export_dir, "risk_stratification_summary.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"horizon",
"sex",
"q_total",
"top_decile_y_rate",
"bottom_half_y_rate",
"lift_top10_vs_bottom50",
"slope_pred_vs_obs",
],
rs_sum_rows,
)
write_simple_csv(
os.path.join(export_dir, "lift_capture_points.csv"),
[
@@ -2222,45 +1951,23 @@ def main() -> int:
],
cap_curve_rows,
)
write_simple_csv(
os.path.join(export_dir, "calibration_groups_summary.csv"),
os.path.join(export_dir, "high_risk_cause_concentration.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"horizon",
"sex",
"horizon_group",
"brier_mean",
"brier_median",
"ici_mean",
"ici_median",
"topk",
"event_rate_mean",
"event_rate_median",
"n_total",
"horizon_grouping_method",
],
cal_group_sum_rows,
)
write_simple_csv(
os.path.join(export_dir, "calibration_groups_bins.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"cause",
"sex",
"horizon_group",
"q",
"n_bin",
"p_mean",
"y_rate",
"q_total",
"horizon_grouping_method",
],
cal_group_bins_rows,
conc_rows,
)
# Manifest markdown (stable, user-facing)
@@ -2272,13 +1979,9 @@ def main() -> int:
"All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
"## Files\n\n"
"- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
"- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n"
"- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n"
"- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n"
"- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
"- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
"- calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket.\n"
"- calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). Intended plot: 3-panel reliability curves per model.\n"
"- high_risk_cause_concentration.csv: Event Rate@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart of Event Rate@K vs K.\n"
)
meta = {
@@ -2293,12 +1996,11 @@ def main() -> int:
"notes": {
"label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
"primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
"secondary_metrics": "cause_auc (discrimination) with optional CI",
"secondary_metrics": "cause_auc (discrimination)",
"exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
"warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
"exports_dir": export_dir,
"focus_causes": focus_causes,
"horizon_grouping_method": hg_method,
},
}
with open(args.out_meta_json, "w") as f: