Enhance Event Rate@K and Recall@K computations with random ranking baseline and additional metrics

This commit is contained in:
2026-01-11 00:52:35 +08:00
parent 4d53f52aa1
commit d8b322cbee

View File

@@ -943,7 +943,11 @@ def compute_event_rate_at_topk_causes(
topk_list: list of K values to evaluate
Returns:
List of rows with: topk, mean, median, n_total.
List of rows with:
- topk
- event_rate_mean / event_rate_median
- recall_mean / recall_median (averaged over individuals with >=1 true cause)
- n_total / n_valid_recall
"""
p = np.asarray(p_tau, dtype=float)
y = np.asarray(y_tau, dtype=float)
@@ -960,7 +964,10 @@ def compute_event_rate_at_topk_causes(
"topk": int(max(1, int(kk))),
"event_rate_mean": float("nan"),
"event_rate_median": float("nan"),
"recall_mean": float("nan"),
"recall_median": float("nan"),
"n_total": int(n),
"n_valid_recall": 0,
}
)
return out
@@ -985,19 +992,140 @@ def compute_event_rate_at_topk_causes(
kk_eff = min(int(kk), int(k_total))
idx = top_sorted[:, :kk_eff]
y_sel = np.take_along_axis(y, idx, axis=1)
# fraction of selected causes that occur
per_person = np.mean(y_sel, axis=1)
# Selected true causes per person
hit = np.sum(y_sel, axis=1)
# Precision-like: fraction of selected causes that occur
per_person = hit / \
float(kk_eff) if kk_eff > 0 else np.full((n,), np.nan)
# Recall@K: fraction of true causes covered by top-K (undefined when no true cause)
g = np.sum(y, axis=1)
valid = g > 0
recall = np.full((n,), np.nan, dtype=float)
recall[valid] = hit[valid] / g[valid]
out_rows.append(
{
"topk": int(kk_eff),
"event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"),
"event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"),
"recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"),
"recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"),
"n_total": int(n),
"n_valid_recall": int(np.sum(valid)),
}
)
return out_rows
def compute_random_ranking_baseline_topk(
    y_tau: np.ndarray,
    topk_list: Sequence[int],
    *,
    z: float = 1.645,
) -> List[Dict[str, Any]]:
    """Random ranking baseline for Event Rate@K and Recall@K.

    Baseline definition:
        - For each individual, pick K causes uniformly at random without replacement.
        - EventRate@K = (# selected causes that occur) / K.
        - Recall@K = (# selected causes that occur) / (# causes that occur),
          averaged over individuals with >=1 true cause.

    Under random selection the per-person hit count X follows a hypergeometric
    distribution; this function reports the exact expected means and an
    approximate central interval for the *population mean* using a normal
    approximation of the hypergeometric variance.

    Args:
        y_tau: (N, K_total) binary labels (1 = cause occurs within horizon).
        topk_list: K values to evaluate (non-positive entries ignored,
            duplicates collapsed).
        z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%).

    Returns:
        One row per K with baseline means and p05/p95 bounds for both metrics,
        plus bookkeeping fields (n_total, n_valid_recall, k_total,
        baseline_method).

    Raises:
        ValueError: if y_tau is not 2-dimensional.
    """
    y = np.asarray(y_tau, dtype=float)
    if y.ndim != 2:
        raise ValueError(
            "compute_random_ranking_baseline_topk expects y_tau with shape (N,K)")
    n, k_total = y.shape
    topks = sorted({int(x) for x in topk_list if int(x) > 0})
    if not topks:
        return []
    g = np.sum(y, axis=1)  # (N,) number of true causes per individual
    valid = g > 0
    n_valid = int(np.sum(valid))
    out: List[Dict[str, Any]] = []
    for kk in topks:
        kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk)
        if n == 0 or k_total == 0 or kk_eff <= 0:
            # Degenerate input: no individuals or no causes -> NaN metrics.
            out.append(
                {
                    "topk": int(max(1, kk_eff)),
                    "baseline_event_rate_mean": float("nan"),
                    "baseline_event_rate_p05": float("nan"),
                    "baseline_event_rate_p95": float("nan"),
                    "baseline_recall_mean": float("nan"),
                    "baseline_recall_p05": float("nan"),
                    "baseline_recall_p95": float("nan"),
                    "n_total": int(n),
                    "n_valid_recall": int(n_valid),
                    "k_total": int(k_total),
                    "baseline_method": "random_ranking_hypergeometric_normal_approx",
                }
            )
            continue
        # Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total.
        er_mean = float(np.mean(g / float(k_total)))
        # Variance of hypergeometric count X:
        # Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p = g/K_total.
        # When K == K_total (or K_total == 1) every cause is selected -> Var(X) = 0.
        if k_total > 1 and kk_eff < k_total:
            p = g / float(k_total)
            finite_corr = (float(k_total - kk_eff) / float(k_total - 1))
            var_x = float(kk_eff) * p * (1.0 - p) * finite_corr
        else:
            var_x = np.zeros_like(g, dtype=float)
        # Var(X/K) per person; SE of the population mean assumes independence.
        var_er = var_x / (float(kk_eff) ** 2)
        se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n))
        er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0))
        er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0))
        # Expected Recall@K for individuals with g>0 is K/K_total (clipped).
        rec_mean = float(min(float(kk_eff) / float(k_total), 1.0))
        if n_valid > 0:
            gv = g[valid]
            var_xv = var_x[valid]
            # Var(X / g) = Var(X) / g^2 (g is fixed per individual).
            var_rec_v = var_xv / (gv ** 2)
            se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid)
            rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0))
            rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0))
        else:
            rec_p05 = float("nan")
            rec_p95 = float("nan")
        out.append(
            {
                "topk": int(kk_eff),
                "baseline_event_rate_mean": er_mean,
                "baseline_event_rate_p05": er_p05,
                "baseline_event_rate_p95": er_p95,
                "baseline_recall_mean": rec_mean,
                "baseline_recall_p05": rec_p05,
                "baseline_recall_p95": rec_p95,
                "n_total": int(n),
                "n_valid_recall": int(n_valid),
                "k_total": int(k_total),
                "baseline_method": "random_ranking_hypergeometric_normal_approx",
            }
        )
    return out
def count_occurs_within_horizon(
loader: DataLoader,
offset_years: float,
@@ -1560,6 +1688,12 @@ def main() -> int:
default=[5, 10, 20, 50],
help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
)
ap.add_argument(
"--cause_concentration_write_random_baseline",
action="store_true",
default=False,
help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)",
)
args = ap.parse_args()
set_deterministic(args.seed)
@@ -1645,6 +1779,7 @@ def main() -> int:
cap_points_rows: List[Dict[str, Any]] = []
cap_curve_rows: List[Dict[str, Any]] = []
conc_rows: List[Dict[str, Any]] = []
conc_base_rows: List[Dict[str, Any]] = []
# Track per-model integrity status for meta JSON.
integrity_meta: Dict[str, Any] = {}
@@ -1774,6 +1909,21 @@ def main() -> int:
}
)
if bool(args.cause_concentration_write_random_baseline):
for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes):
conc_base_rows.append(
{
"model_id": model_id,
"model_type": model_type,
"loss_type": loss_type_id,
"age_encoder": age_encoder,
"cov_type": cov_type,
"horizon": float(tau),
"sex": sex_label,
**rr,
}
)
# Convenience slices for user-facing experiments (focus causes only).
cause_cif_focus = cif_full[:, top_cause_ids, :]
y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
@@ -1965,11 +2115,41 @@ def main() -> int:
"topk",
"event_rate_mean",
"event_rate_median",
"recall_mean",
"recall_median",
"n_total",
"n_valid_recall",
],
conc_rows,
)
if conc_base_rows:
write_simple_csv(
os.path.join(
export_dir, "high_risk_cause_concentration_random_baseline.csv"),
[
"model_id",
"model_type",
"loss_type",
"age_encoder",
"cov_type",
"horizon",
"sex",
"topk",
"baseline_event_rate_mean",
"baseline_event_rate_p05",
"baseline_event_rate_p95",
"baseline_recall_mean",
"baseline_recall_p05",
"baseline_recall_p95",
"n_total",
"n_valid_recall",
"k_total",
"baseline_method",
],
conc_base_rows,
)
# Manifest markdown (stable, user-facing)
manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
with open(manifest_path, "w", encoding="utf-8") as f:
@@ -1981,7 +2161,8 @@ def main() -> int:
"- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
"- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
"- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
"- high_risk_cause_concentration.csv: Event Rate@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart of Event Rate@K vs K.\n"
"- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart vs K.\n"
"- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n"
)
meta = {