Refactor AUC computation methods and introduce Event Rate@K for cross-cause prioritization
@@ -4,7 +4,6 @@ import json
import math
import os
import random
import statistics
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -522,7 +521,7 @@ def check_cif_integrity(
# Metrics
# ============================================================

# --- Standard fast DeLong AUC variance + CI (ties handled via midranks) ---
# --- Rank-based ROC AUC (ties handled via midranks) ---

def compute_midrank(x: np.ndarray) -> np.ndarray:
    """Vectorized midrank computation (ties -> average ranks)."""
@@ -554,75 +553,6 @@ def compute_midrank(x: np.ndarray) -> np.ndarray:
    return out
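# Worked example (illustrative, not part of this commit): compute_midrank on
# [0.2, 0.4, 0.4, 0.9] returns [1.0, 2.5, 2.5, 4.0] -- the tied pair shares
# the average of ranks 2 and 3.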


def fastDeLong(predictions_sorted_transposed: np.ndarray, label_1_count: int) -> Tuple[np.ndarray, np.ndarray]:
    """Fast DeLong method for computing the AUC and its covariance.

    predictions_sorted_transposed: shape (n_classifiers, n_examples), with the
    positive examples occupying the first label_1_count columns.
    """
    preds = np.asarray(predictions_sorted_transposed, dtype=float)
    m = int(label_1_count)
    n = int(preds.shape[1] - m)
    if m <= 0 or n <= 0:
        return np.array([float("nan")]), np.array([[float("nan")]])

    pos = preds[:, :m]
    neg = preds[:, m:]

    tx = np.array([compute_midrank(x) for x in pos])
    ty = np.array([compute_midrank(x) for x in neg])
    tz = np.array([compute_midrank(x) for x in preds])

    aucs = (tz[:, :m].sum(axis=1) - m * (m + 1) / 2.0) / (m * n)
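    # The line above is the Mann–Whitney identity: AUC equals the sum of
    # positive-class midranks minus m*(m+1)/2, normalized by m*n.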

    v01 = (tz[:, :m] - tx) / n
    v10 = 1.0 - (tz[:, m:] - ty) / m
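    # v01[i] estimates, per positive example, the fraction of negatives it
    # outranks (ties counting 1/2); v10[j] is the mirror quantity per negative.
    # Both average to the AUC; their variances drive the DeLong covariance.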

    if v01.shape[0] > 1:
        sx = np.cov(v01)
        sy = np.cov(v10)
    else:
        # Single-classifier case: compute row-wise variance (do not flatten).
        var_v01 = float(np.var(v01, axis=1, ddof=1)[0])
        var_v10 = float(np.var(v10, axis=1, ddof=1)[0])
        sx = np.array([[var_v01]])
        sy = np.array([[var_v10]])
    delong_cov = sx / m + sy / n
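    # DeLong's estimator: Var(AUC) ~ Var(v01)/m + Var(v10)/n, i.e. the two
    # structural-component variances scaled by their class counts.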
    return aucs, delong_cov


def calc_auc_variance(ground_truth: np.ndarray, predictions: np.ndarray) -> Tuple[float, float]:
    y = np.asarray(ground_truth, dtype=int)
    p = np.asarray(predictions, dtype=float)
    if y.ndim != 1 or p.ndim != 1 or y.shape[0] != p.shape[0]:
        raise ValueError("calc_auc_variance expects 1D arrays of equal length")

    m = int(np.sum(y == 1))
    n = int(np.sum(y == 0))
    if m == 0 or n == 0:
        return float("nan"), float("nan")

    order = np.argsort(-y)  # positives first
    preds_sorted = p[order]
    aucs, cov = fastDeLong(preds_sorted[np.newaxis, :], m)
    auc = float(aucs[0])
    var = float(cov[0, 0])
    return auc, var


def delong_ci(ground_truth: np.ndarray, predictions: np.ndarray, alpha: float = 0.95) -> Tuple[float, float, float]:
    """Return (auc, ci_low, ci_high) using the DeLong variance and a normal CI."""
    auc, var = calc_auc_variance(ground_truth, predictions)
    if not np.isfinite(var) or var <= 0:
        print("WARNING: DeLong variance is non-positive or NaN; CI set to NaN")
        return float(auc), float("nan"), float("nan")

    sd = math.sqrt(var)
    z = statistics.NormalDist().inv_cdf(1.0 - (1.0 - float(alpha)) / 2.0)
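    # Two-sided normal quantile; for alpha = 0.95 this is z ~ 1.96.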
    lo = max(0.0, auc - z * sd)
    hi = min(1.0, auc + z * sd)
    return float(auc), float(lo), float(hi)


def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Rank-based ROC AUC via the Mann–Whitney U statistic (ties handled by midranks).

@@ -643,49 +573,6 @@ def roc_auc_rank(y_true: np.ndarray, y_score: np.ndarray) -> float:
    return float(auc)
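# Sanity check (illustrative): y_true=[0, 0, 1, 1], y_score=[0.1, 0.4, 0.35, 0.8]
# gives AUC = 0.75, since exactly one of the four positive/negative pairs is
# ranked in the wrong order.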


def bootstrap_auc_ci(
    scores: np.ndarray,
    labels: np.ndarray,
    n_bootstrap: int,
    alpha: float = 0.95,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Bootstrap CI for ROC AUC (percentile)."""
    rng = np.random.default_rng(int(seed))
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    n = labels.shape[0]
    if n == 0 or np.all(labels == labels[0]):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    auc_full = roc_auc_rank(labels, scores)
    if not np.isfinite(auc_full):
        print("WARNING: bootstrap AUC CI degenerate labels; CI set to NaN")
        return float("nan"), float("nan"), float("nan")

    aucs: List[float] = []
    for _ in range(int(n_bootstrap)):
        idx = rng.integers(0, n, size=n)
        yb = labels[idx]
        if np.all(yb == yb[0]):
            continue
        pb = scores[idx]
        auc = roc_auc_rank(yb, pb)
        if np.isfinite(auc):
            aucs.append(float(auc))

    if len(aucs) < 10:
        print("WARNING: bootstrap AUC CI has too few valid resamples; CI set to NaN")
        return float(auc_full), float("nan"), float("nan")

    lo_q = (1.0 - float(alpha)) / 2.0
    hi_q = 1.0 - lo_q
    lo = float(np.quantile(aucs, lo_q))
    hi = float(np.quantile(aucs, hi_q))
    return float(auc_full), lo, hi
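# Note: this is a percentile bootstrap -- the CI is read off the empirical
# quantiles of the resampled AUCs. Single-class resamples are skipped above,
# which is why the count of valid resamples is checked before quoting a CI.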


def brier_score(p: np.ndarray, y: np.ndarray) -> float:
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
@@ -1040,23 +927,75 @@ def compute_capture_points(
    return rows


def make_horizon_groups(horizons: Sequence[float]) -> Tuple[List[Dict[str, Any]], Dict[float, str], str]:
    """Bucketize horizons into short/medium/long using the continuous-horizon rule."""
    uniq = sorted({float(h) for h in horizons})
    mapping: Dict[float, str] = {}
    rows: List[Dict[str, Any]] = []
    # First 4 short, next 4 medium, rest long.
    for i, h in enumerate(uniq):
        if i < 4:
            g, gr = "short", 1
        elif i < 8:
            g, gr = "medium", 2
        else:
            g, gr = "long", 3
        mapping[float(h)] = g
        rows.append({"horizon": float(h), "group": g, "group_rank": int(gr)})
    method = "continuous_unique_horizons_first4_next4_rest"
    return rows, mapping, method
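# Example (illustrative): ten unique horizons 1..10 map to short = {1, 2, 3, 4},
# medium = {5, 6, 7, 8}, long = {9, 10}; horizons are deduplicated and sorted
# before bucketing, so duplicates never shift the boundaries.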


def compute_event_rate_at_topk_causes(
    p_tau: np.ndarray,
    y_tau: np.ndarray,
    topk_list: Sequence[int],
) -> List[Dict[str, Any]]:
    """Compute Event Rate@K for cross-cause prioritization.

    For each individual, rank all causes by predicted risk p_tau at a fixed
    horizon. For each K, select that individual's top-K causes and compute the
    fraction of the selected causes that actually occur within the horizon.

    Args:
        p_tau: (N, n_causes) predicted CIFs at a fixed horizon
        y_tau: (N, n_causes) binary labels (0/1): cause occurs within the horizon
        topk_list: list of K values to evaluate

    Returns:
        List of rows with: topk, event_rate_mean, event_rate_median, n_total.
    """
    p = np.asarray(p_tau, dtype=float)
    y = np.asarray(y_tau, dtype=float)
    if p.ndim != 2 or y.ndim != 2 or p.shape != y.shape:
        raise ValueError(
            "compute_event_rate_at_topk_causes expects (N, n_causes) arrays of equal shape")

    n, k_total = p.shape
    if n == 0 or k_total == 0:
        out: List[Dict[str, Any]] = []
        for kk in topk_list:
            out.append(
                {
                    "topk": int(max(1, int(kk))),
                    "event_rate_mean": float("nan"),
                    "event_rate_median": float("nan"),
                    "n_total": int(n),
                }
            )
        return out

    # Sanitize K list.
    topks = sorted({int(x) for x in topk_list if int(x) > 0})
    if not topks:
        return []

    max_k = min(int(max(topks)), int(k_total))
    if max_k <= 0:
        return []

    # Efficient: get the top max_k causes per individual, then sort within those.
    part = np.argpartition(-p, kth=max_k - 1, axis=1)[:, :max_k]  # (N, max_k)
    p_part = np.take_along_axis(p, part, axis=1)
    order = np.argsort(-p_part, axis=1)
    top_sorted = np.take_along_axis(part, order, axis=1)  # (N, max_k)
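    # Two-stage top-K: argpartition finds the max_k largest entries per row
    # without a full sort, and the argsort then orders only those max_k
    # columns, so each smaller K is just a prefix slice of top_sorted.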

    out_rows: List[Dict[str, Any]] = []
    for kk in topks:
        kk_eff = min(int(kk), int(k_total))
        idx = top_sorted[:, :kk_eff]
        y_sel = np.take_along_axis(y, idx, axis=1)
        # fraction of selected causes that occur
        per_person = np.mean(y_sel, axis=1)
        out_rows.append(
            {
                "topk": int(kk_eff),
                "event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"),
                "event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"),
                "n_total": int(n),
            }
        )
    return out_rows
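# Minimal sketch (hypothetical values): with p_tau = [[0.9, 0.1, 0.4]],
# y_tau = [[1, 0, 1]] and topk_list = [2], the top-2 causes are indices 0 and 2,
# both of which occur, so event_rate_mean == 1.0 for that row.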


def count_occurs_within_horizon(
@@ -1293,8 +1232,6 @@ def evaluate_one_model(
    out_rows: List[Dict[str, Any]],
    calib_rows: List[Dict[str, Any]],
    calib_cause_ids: Optional[Sequence[int]],
    auc_ci_method: str,
    bootstrap_n: int,
    n_calib_bins: int = 10,
    metric_workers: int = 0,
    progress: str = "auto",
@@ -1375,15 +1312,8 @@ def evaluate_one_model(
        }
    )

    # Secondary: discrimination via AUC at the same horizon.
    if auc_ci_method == "none":
        auc, lo, hi = float("nan"), float("nan"), float("nan")
    elif auc_ci_method == "bootstrap":
        auc, lo, hi = bootstrap_auc_ci(
            p, y, n_bootstrap=bootstrap_n, alpha=0.95
        )
    else:
        auc, lo, hi = delong_ci(y, p, alpha=0.95)
    # Secondary: discrimination via AUC at the same horizon (point estimate only).
    auc = roc_auc_rank(y, p)

    local_rows.append(
        {
@@ -1392,8 +1322,8 @@ def evaluate_one_model(
            "horizon": float(tau),
            "cause": int(cid),
            "value": float(auc),
            "ci_low": lo,
            "ci_high": hi,
            "ci_low": "",
            "ci_high": "",
        }
    )

@@ -1592,15 +1522,6 @@ def main() -> int:
    ap.add_argument("--integrity_strict", action="store_true", default=False)
    ap.add_argument("--integrity_tol", type=float, default=1e-6)

    # AUC CI methods
    ap.add_argument(
        "--auc_ci_method",
        type=str,
        default="delong",
        choices=["delong", "bootstrap", "none"],
    )
    ap.add_argument("--bootstrap_n", type=int, default=2000)

    # Speed/UX
    ap.add_argument(
        "--metric_workers",
@@ -1630,6 +1551,15 @@ def main() -> int:
        default=50,
        help="If >0, also export a dense capture curve for k=1..max_pct",
    )

    # High-risk cause concentration (cross-cause prioritization)
    ap.add_argument(
        "--cause_concentration_topk",
        type=int,
        nargs="*",
        default=[5, 10, 20, 50],
        help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
    )
    args = ap.parse_args()

    set_deterministic(args.seed)
@@ -1708,25 +1638,13 @@ def main() -> int:
        }
    )

    # Horizon groups for Experiment 3
    hg_rows, horizon_to_group, hg_method = make_horizon_groups(
        args.eval_horizons)
    write_simple_csv(
        os.path.join(export_dir, "horizon_groups.csv"),
        ["horizon", "group", "group_rank"],
        hg_rows,
    )

    summary_rows: List[Dict[str, Any]] = []
    calib_rows: List[Dict[str, Any]] = []

    # Experiment exports (accumulated across models)
    rs_bins_rows: List[Dict[str, Any]] = []
    rs_sum_rows: List[Dict[str, Any]] = []
    cap_points_rows: List[Dict[str, Any]] = []
    cap_curve_rows: List[Dict[str, Any]] = []
    cal_group_sum_rows: List[Dict[str, Any]] = []
    cal_group_bins_rows: List[Dict[str, Any]] = []
    conc_rows: List[Dict[str, Any]] = []

    # Track per-model integrity status for meta JSON.
    integrity_meta: Dict[str, Any] = {}
@@ -1817,8 +1735,6 @@ def main() -> int:
        out_rows=model_rows,
        calib_rows=calib_rows,
        calib_cause_ids=top_cause_ids.tolist(),
        auc_ci_method=str(args.auc_ci_method),
        bootstrap_n=int(args.bootstrap_n),
        metric_workers=int(args.metric_workers),
        progress=str(args.progress),
    )
@@ -1831,61 +1747,37 @@ def main() -> int:
    )
    summary_rows.extend(model_summary_rows)

    # ============================================================
    # Experiment: High-Risk Cause Concentration at fixed horizon
    # (cross-cause prioritization accuracy)
    # ============================================================
    topk_causes = [int(x) for x in args.cause_concentration_topk]
    for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
        for h_i, tau in enumerate(args.eval_horizons):
            p_tau_all = np.asarray(cif_full[:, :, h_i], dtype=float)
            y_tau_all = np.asarray(
                y_cause_within_tau[:, :, h_i], dtype=float)
            if sex_mask is not None:
                p_tau_all = p_tau_all[sex_mask]
                y_tau_all = y_tau_all[sex_mask]
            for rr in compute_event_rate_at_topk_causes(p_tau_all, y_tau_all, topk_causes):
                conc_rows.append(
                    {
                        "model_id": model_id,
                        "model_type": model_type,
                        "loss_type": loss_type_id,
                        "age_encoder": age_encoder,
                        "cov_type": cov_type,
                        "horizon": float(tau),
                        "sex": sex_label,
                        **rr,
                    }
                )

    # Convenience slices for user-facing experiments (focus causes only).
    cause_cif_focus = cif_full[:, top_cause_ids, :]
    y_within_focus = y_cause_within_tau[:, top_cause_ids, :]

    # ============================================================
    # Experiment 1: Risk stratification bins + summary
    # ============================================================
    for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
        for h_i, tau in enumerate(args.eval_horizons):
            for j, cause_id in enumerate(top_cause_ids.tolist()):
                p = cause_cif_focus[:, j, h_i]
                y = y_within_focus[:, j, h_i]
                if sex_mask is not None:
                    p = p[sex_mask]
                    y = y[sex_mask]
                q_used, bin_rows, summary = compute_risk_stratification_bins(
                    p, y, q_default=10)
                for br in bin_rows:
                    rs_bins_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "cause": int(cause_id),
                            "horizon": float(tau),
                            "sex": sex_label,
                            "q": int(br["q"]),
                            "n_bin": int(br["n_bin"]),
                            "p_mean": _safe_float(br["p_mean"]),
                            "y_rate": _safe_float(br["y_rate"]),
                            "y_overall": _safe_float(br["y_overall"]),
                            "lift_vs_overall": _safe_float(br["lift_vs_overall"]),
                            "q_total": int(q_used),
                        }
                    )
                rs_sum_rows.append(
                    {
                        "model_id": model_id,
                        "model_type": model_type,
                        "loss_type": loss_type_id,
                        "age_encoder": age_encoder,
                        "cov_type": cov_type,
                        "cause": int(cause_id),
                        "horizon": float(tau),
                        "sex": sex_label,
                        "q_total": int(q_used),
                        "top_decile_y_rate": _safe_float(summary["top_decile_y_rate"]),
                        "bottom_half_y_rate": _safe_float(summary["bottom_half_y_rate"]),
                        "lift_top10_vs_bottom50": _safe_float(summary["lift_top10_vs_bottom50"]),
                        "slope_pred_vs_obs": _safe_float(summary["slope_pred_vs_obs"]),
                    }
                )

    # ============================================================
    # Experiment 2: High-risk capture points (+ optional curve)
    # ============================================================
@@ -1932,129 +1824,6 @@ def main() -> int:
        }
    )

    # ============================================================
    # Experiment 3: Short/Medium/Long horizon-group calibration
    # ============================================================
    # Per-horizon metrics for grouping.
    # Build a dict for quick access: (cause_id, horizon) -> (brier, ici).
    per_h: Dict[Tuple[int, float], Dict[str, float]] = {}
    for rr in model_rows:
        if rr.get("metric_name") not in {"cause_brier", "cause_ici"}:
            continue
        try:
            cid = int(rr.get("cause"))
        except Exception:
            continue
        if cid not in set(int(x) for x in top_cause_ids.tolist()):
            continue
        h = _safe_float(rr.get("horizon"))
        if not np.isfinite(h):
            continue
        key = (cid, float(h))
        d = per_h.get(key, {})
        d[str(rr.get("metric_name"))] = _safe_float(rr.get("value"))
        per_h[key] = d

    # Compute group summaries and pooled bins using the same quantile bins as exp1 (per slice).
    for sex_label, sex_mask in _sex_slices(sex if sex.size else None):
        for j, cause_id in enumerate(top_cause_ids.tolist()):
            # Decide Q per slice for the pooled reliability curve.
            n_slice = int(np.sum(sex_mask)) if sex_mask is not None else int(
                sex.shape[0])
            q_pool = 10 if n_slice >= 200 else 5
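            # Heuristic: smaller slices get quintiles instead of deciles so
            # the pooled reliability bins do not become too sparse.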

            # Collect per-horizon brier/ici values.
            group_vals: Dict[str, Dict[str, List[float]]] = {
                "short": {"brier": [], "ici": []},
                "medium": {"brier": [], "ici": []},
                "long": {"brier": [], "ici": []},
            }
            group_n_total: Dict[str, int] = {
                "short": 0, "medium": 0, "long": 0}

            # Pooled bins: group -> q -> accumulators.
            pooled: Dict[str, Dict[int, Dict[str, float]]] = {
                "short": {}, "medium": {}, "long": {}}

            for h_i, tau in enumerate(args.eval_horizons):
                g = horizon_to_group.get(float(tau), "long")

                # Brier/ICI per horizon (already computed at full-sample level).
                d = per_h.get((int(cause_id), float(tau)), {})
                brier_h = _safe_float(d.get("cause_brier"))
                ici_h = _safe_float(d.get("cause_ici"))
                if np.isfinite(brier_h):
                    group_vals[g]["brier"].append(brier_h)
                if np.isfinite(ici_h):
                    group_vals[g]["ici"].append(ici_h)

                # Pooled reliability bins from raw p/y.
                p = cause_cif_focus[:, j, h_i]
                y = y_within_focus[:, j, h_i]
                if sex_mask is not None:
                    p = p[sex_mask]
                    y = y[sex_mask]
                if p.size == 0:
                    continue
                edges = _quantile_edges(p, q_pool)
                for qi in range(q_pool):
                    m = (p > edges[qi]) & (p <= edges[qi + 1])
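                    # Half-open (lo, hi] bins; this assumes _quantile_edges
                    # places the first edge below min(p) so the smallest
                    # prediction still lands in bin 1.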
                    nb = int(np.sum(m))
                    if nb == 0:
                        continue
                    pm = float(np.mean(p[m]))
                    yr = float(np.mean(y[m]))
                    acc = pooled[g].get(
                        qi + 1, {"n": 0.0, "p_sum": 0.0, "y_sum": 0.0})
                    acc["n"] += float(nb)
                    acc["p_sum"] += float(nb) * pm
                    acc["y_sum"] += float(nb) * yr
                    pooled[g][qi + 1] = acc
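                    # Size-weighted accumulation: p_sum/y_sum carry nb * mean,
                    # so the later p_sum / n and y_sum / n are exact per-sample
                    # means pooled across the horizons in the group.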
                    group_n_total[g] = max(group_n_total[g], int(p.size))

            for g in ["short", "medium", "long"]:
                bvals = group_vals[g]["brier"]
                ivals = group_vals[g]["ici"]
                cal_group_sum_rows.append(
                    {
                        "model_id": model_id,
                        "model_type": model_type,
                        "loss_type": loss_type_id,
                        "age_encoder": age_encoder,
                        "cov_type": cov_type,
                        "cause": int(cause_id),
                        "sex": sex_label,
                        "horizon_group": g,
                        "brier_mean": float(np.mean(bvals)) if bvals else float("nan"),
                        "brier_median": float(np.median(bvals)) if bvals else float("nan"),
                        "ici_mean": float(np.mean(ivals)) if ivals else float("nan"),
                        "ici_median": float(np.median(ivals)) if ivals else float("nan"),
                        "n_total": int(group_n_total[g]),
                        "horizon_grouping_method": hg_method,
                    }
                )

                for qi in range(1, q_pool + 1):
                    acc = pooled[g].get(qi)
                    if not acc or float(acc.get("n", 0.0)) <= 0:
                        continue
                    n_bin = float(acc["n"])
                    cal_group_bins_rows.append(
                        {
                            "model_id": model_id,
                            "model_type": model_type,
                            "loss_type": loss_type_id,
                            "age_encoder": age_encoder,
                            "cov_type": cov_type,
                            "cause": int(cause_id),
                            "sex": sex_label,
                            "horizon_group": g,
                            "q": int(qi),
                            "n_bin": int(n_bin),
                            "p_mean": float(acc["p_sum"] / n_bin),
                            "y_rate": float(acc["y_sum"] / n_bin),
                            "q_total": int(q_pool),
                            "horizon_grouping_method": hg_method,
                        }
                    )

    # Optionally write top-cause counts into the main results CSV as metric rows.
    for tc in top_causes_meta:
        model_rows.append(
@@ -2141,46 +1910,6 @@ def main() -> int:
    write_calibration_bins_csv(calib_csv_path, calib_rows)

    # Write experiment exports.
    write_simple_csv(
        os.path.join(export_dir, "risk_stratification_bins.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "q",
            "n_bin",
            "p_mean",
            "y_rate",
            "y_overall",
            "lift_vs_overall",
            "q_total",
        ],
        rs_bins_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "risk_stratification_summary.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "q_total",
            "top_decile_y_rate",
            "bottom_half_y_rate",
            "lift_top10_vs_bottom50",
            "slope_pred_vs_obs",
        ],
        rs_sum_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "lift_capture_points.csv"),
        [
@@ -2222,45 +1951,23 @@ def main() -> int:
        ],
        cap_curve_rows,
    )

    write_simple_csv(
        os.path.join(export_dir, "calibration_groups_summary.csv"),
        os.path.join(export_dir, "high_risk_cause_concentration.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "horizon",
            "sex",
            "horizon_group",
            "brier_mean",
            "brier_median",
            "ici_mean",
            "ici_median",
            "topk",
            "event_rate_mean",
            "event_rate_median",
            "n_total",
            "horizon_grouping_method",
        ],
        cal_group_sum_rows,
    )
    write_simple_csv(
        os.path.join(export_dir, "calibration_groups_bins.csv"),
        [
            "model_id",
            "model_type",
            "loss_type",
            "age_encoder",
            "cov_type",
            "cause",
            "sex",
            "horizon_group",
            "q",
            "n_bin",
            "p_mean",
            "y_rate",
            "q_total",
            "horizon_grouping_method",
        ],
        cal_group_bins_rows,
        conc_rows,
    )

    # Manifest markdown (stable, user-facing)
@@ -2272,13 +1979,9 @@ def main() -> int:
        "All exports are per-cause and per-horizon unless explicitly aggregated. No all-cause aggregates and no ECE are produced.\n\n"
        "## Files\n\n"
        "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
        "- horizon_groups.csv: Mapping from each horizon to short/medium/long buckets. Intended plot: annotate calibration comparisons.\n"
        "- risk_stratification_bins.csv: Quantile bins (deciles or quintiles) with predicted vs observed event rates and lift. Intended plot: reliability-by-risk-tier lines.\n"
        "- risk_stratification_summary.csv: Compact stratification summaries (top decile vs bottom half lift, slope). Intended plot: slide-friendly comparison table.\n"
        "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
        "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
        "- calibration_groups_summary.csv: Short/medium/long aggregated Brier/ICI (mean/median). Intended plot: grouped bar chart by horizon bucket.\n"
        "- calibration_groups_bins.csv: Pooled reliability points per horizon bucket (weighted by bin size). Intended plot: 3-panel reliability curves per model.\n"
        "- high_risk_cause_concentration.csv: Event Rate@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart of Event Rate@K vs K.\n"
    )

    meta = {
@@ -2293,12 +1996,11 @@ def main() -> int:
        "notes": {
            "label": "Cause-specific, horizon-specific: disease k occurs within tau after context (at least once in (t_ctx, t_ctx+tau])",
            "primary_metrics": "cause_brier (CIF-based) and cause_ici (calibration)",
            "secondary_metrics": "cause_auc (discrimination) with optional CI",
            "secondary_metrics": "cause_auc (discrimination)",
            "exclusions": "No all-cause aggregation; no next-event formulation; ECE not reported",
            "warning": "This evaluation does not IPCW-weight censoring because the dataset loader does not expose an explicit censoring time.",
            "exports_dir": export_dir,
            "focus_causes": focus_causes,
            "horizon_grouping_method": hg_method,
        },
    }
    with open(args.out_meta_json, "w") as f: