Enhance Event Rate@K and Recall@K computations with random ranking baseline and additional metrics
@@ -943,7 +943,11 @@ def compute_event_rate_at_topk_causes(
         topk_list: list of K values to evaluate
 
     Returns:
-        List of rows with: topk, mean, median, n_total.
+        List of rows with:
+        - topk
+        - event_rate_mean / event_rate_median
+        - recall_mean / recall_median (averaged over individuals with >=1 true cause)
+        - n_total / n_valid_recall
     """
     p = np.asarray(p_tau, dtype=float)
     y = np.asarray(y_tau, dtype=float)
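
Editor's note: to make the two new docstring metrics concrete — EventRate@K is the fraction of the K selected causes that actually occur (precision-like), Recall@K is the fraction of the person's true causes covered by the top-K. A tiny worked example (illustration only, not part of the commit):

# Worked example of EventRate@K and Recall@K for a single person.
import numpy as np

p = np.array([[0.9, 0.1, 0.8, 0.2]])  # predicted CIFs for 4 causes
y = np.array([[1.0, 0.0, 0.0, 1.0]])  # causes 0 and 3 actually occur

k = 2
idx = np.argsort(-p, axis=1)[:, :k]                   # top-2 by risk: causes 0 and 2
hit = np.take_along_axis(y, idx, axis=1).sum(axis=1)  # 1 of the 2 picks occurs

print(hit / k)              # EventRate@2 = [0.5]
print(hit / y.sum(axis=1))  # Recall@2    = [0.5] (1 of 2 true causes covered)
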
@@ -960,7 +964,10 @@ def compute_event_rate_at_topk_causes(
                     "topk": int(max(1, int(kk))),
                     "event_rate_mean": float("nan"),
                     "event_rate_median": float("nan"),
+                    "recall_mean": float("nan"),
+                    "recall_median": float("nan"),
                     "n_total": int(n),
+                    "n_valid_recall": 0,
                 }
             )
         return out
@@ -985,19 +992,138 @@ def compute_event_rate_at_topk_causes(
         kk_eff = min(int(kk), int(k_total))
         idx = top_sorted[:, :kk_eff]
         y_sel = np.take_along_axis(y, idx, axis=1)
-        # fraction of selected causes that occur
-        per_person = np.mean(y_sel, axis=1)
+        # Selected true causes per person
+        hit = np.sum(y_sel, axis=1)
+        # Precision-like: fraction of selected causes that occur
+        per_person = hit / float(kk_eff) if kk_eff > 0 else np.full((n,), np.nan)
+
+        # Recall@K: fraction of true causes covered by top-K (undefined when no true cause)
+        g = np.sum(y, axis=1)
+        valid = g > 0
+        recall = np.full((n,), np.nan, dtype=float)
+        recall[valid] = hit[valid] / g[valid]
         out_rows.append(
             {
                 "topk": int(kk_eff),
                 "event_rate_mean": float(np.mean(per_person)) if per_person.size else float("nan"),
                 "event_rate_median": float(np.median(per_person)) if per_person.size else float("nan"),
+                "recall_mean": float(np.nanmean(recall)) if int(np.sum(valid)) > 0 else float("nan"),
+                "recall_median": float(np.nanmedian(recall)) if int(np.sum(valid)) > 0 else float("nan"),
                 "n_total": int(n),
+                "n_valid_recall": int(np.sum(valid)),
             }
         )
     return out_rows
+
+
+def compute_random_ranking_baseline_topk(
+    y_tau: np.ndarray,
+    topk_list: Sequence[int],
+    *,
+    z: float = 1.645,
+) -> List[Dict[str, Any]]:
+    """Random ranking baseline for Event Rate@K and Recall@K.
+
+    Baseline definition:
+    - For each individual, pick K causes uniformly at random without replacement.
+    - EventRate@K = (# selected causes that occur) / K.
+    - Recall@K = (# selected causes that occur) / (# causes that occur), averaged over individuals with >=1 true cause.
+
+    This function computes the expected baseline mean and an approximate 5-95% range
+    for the population mean using a normal approximation of the hypergeometric variance.
+
+    Args:
+        y_tau: (N, K_total) binary labels
+        topk_list: K values
+        z: z-score for the central interval; z=1.645 corresponds to ~90% (5-95%)
+
+    Returns:
+        Rows with baseline means and p05/p95 for both metrics.
+    """
+    y = np.asarray(y_tau, dtype=float)
+    if y.ndim != 2:
+        raise ValueError(
+            "compute_random_ranking_baseline_topk expects y_tau with shape (N,K)")
+
+    n, k_total = y.shape
+    topks = sorted({int(x) for x in topk_list if int(x) > 0})
+    if not topks:
+        return []
+
+    g = np.sum(y, axis=1)  # (N,)
+    valid = g > 0
+    n_valid = int(np.sum(valid))
+
+    out: List[Dict[str, Any]] = []
+    for kk in topks:
+        kk_eff = min(int(kk), int(k_total)) if k_total > 0 else int(kk)
+        if n == 0 or k_total == 0 or kk_eff <= 0:
+            out.append(
+                {
+                    "topk": int(max(1, kk_eff)),
+                    "baseline_event_rate_mean": float("nan"),
+                    "baseline_event_rate_p05": float("nan"),
+                    "baseline_event_rate_p95": float("nan"),
+                    "baseline_recall_mean": float("nan"),
+                    "baseline_recall_p05": float("nan"),
+                    "baseline_recall_p95": float("nan"),
+                    "n_total": int(n),
+                    "n_valid_recall": int(n_valid),
+                    "k_total": int(k_total),
+                    "baseline_method": "random_ranking_hypergeometric_normal_approx",
+                }
+            )
+            continue
+
+        # Expected EventRate@K per person is E[X]/K = (K * (g/K_total))/K = g/K_total.
+        er_mean = float(np.mean(g / float(k_total)))
+
+        # Variance of hypergeometric count X:
+        # Var(X) = K * p * (1-p) * ((K_total - K)/(K_total - 1)), where p = g/K_total.
+        if k_total > 1 and kk_eff < k_total:
+            p = g / float(k_total)
+            finite_corr = float(k_total - kk_eff) / float(k_total - 1)
+            var_x = float(kk_eff) * p * (1.0 - p) * finite_corr
+        else:
+            var_x = np.zeros_like(g, dtype=float)
+
+        var_er = var_x / (float(kk_eff) ** 2)
+        se_er_mean = float(np.sqrt(np.sum(var_er))) / float(max(1, n))
+        er_p05 = float(np.clip(er_mean - z * se_er_mean, 0.0, 1.0))
+        er_p95 = float(np.clip(er_mean + z * se_er_mean, 0.0, 1.0))
+
+        # Expected Recall@K for individuals with g>0 is K/K_total (clipped).
+        rec_mean = float(min(float(kk_eff) / float(k_total), 1.0))
+        if n_valid > 0:
+            gv = g[valid]
+            var_xv = var_x[valid]
+            # Var(X / g) = Var(X) / g^2 (g is fixed per individual)
+            var_rec_v = var_xv / (gv ** 2)
+            se_rec_mean = float(np.sqrt(np.sum(var_rec_v))) / float(n_valid)
+            rec_p05 = float(np.clip(rec_mean - z * se_rec_mean, 0.0, 1.0))
+            rec_p95 = float(np.clip(rec_mean + z * se_rec_mean, 0.0, 1.0))
+        else:
+            rec_p05 = float("nan")
+            rec_p95 = float("nan")
+
+        out.append(
+            {
+                "topk": int(kk_eff),
+                "baseline_event_rate_mean": er_mean,
+                "baseline_event_rate_p05": er_p05,
+                "baseline_event_rate_p95": er_p95,
+                "baseline_recall_mean": rec_mean,
+                "baseline_recall_p05": float(rec_p05),
+                "baseline_recall_p95": float(rec_p95),
+                "n_total": int(n),
+                "n_valid_recall": int(n_valid),
+                "k_total": int(k_total),
+                "baseline_method": "random_ranking_hypergeometric_normal_approx",
+            }
+        )
+    return out
+
+
 def count_occurs_within_horizon(
     loader: DataLoader,
     offset_years: float,
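
Editor's note: the closed-form expectations used above (per-person E[EventRate@K] = g/K_total, and E[Recall@K] = K/K_total for individuals with g > 0) are easy to sanity-check by simulating the random ranking directly. A minimal sketch using synthetic numpy labels; none of these names come from the repository:

# Monte Carlo check of the random-ranking baseline's closed-form means.
import numpy as np

rng = np.random.default_rng(0)
n, k_total, kk, trials = 500, 40, 10, 2000

# Synthetic binary labels: each person has a few "true" causes.
y = (rng.random((n, k_total)) < 0.1).astype(float)
g = y.sum(axis=1)
valid = g > 0

er, rec = [], []
for _ in range(trials):
    # Random ranking = pick K causes uniformly without replacement per person.
    idx = np.argsort(rng.random((n, k_total)), axis=1)[:, :kk]
    hit = np.take_along_axis(y, idx, axis=1).sum(axis=1)
    er.append(np.mean(hit / kk))
    rec.append(np.mean(hit[valid] / g[valid]))

print(f"EventRate@{kk}: mc={np.mean(er):.4f} closed={g.mean() / k_total:.4f}")
print(f"Recall@{kk}:    mc={np.mean(rec):.4f} closed={kk / k_total:.4f}")
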
@@ -1560,6 +1688,12 @@ def main() -> int:
        default=[5, 10, 20, 50],
        help="Top-K causes per individual for Event Rate@K (cross-cause prioritization)",
    )
+    ap.add_argument(
+        "--cause_concentration_write_random_baseline",
+        action="store_true",
+        default=False,
+        help="If set, also export a random-ranking baseline (expected Event Rate@K and Recall@K with an uncertainty range)",
+    )
    args = ap.parse_args()
 
    set_deterministic(args.seed)
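
Editor's note: the new option is a standard argparse store_true switch, so it defaults to False unless passed on the command line. A minimal standalone check, reduced to just this argument (hypothetical, not repository code):

import argparse

ap = argparse.ArgumentParser()
ap.add_argument(
    "--cause_concentration_write_random_baseline",
    action="store_true",
    default=False,
)

assert ap.parse_args([]).cause_concentration_write_random_baseline is False
assert ap.parse_args(
    ["--cause_concentration_write_random_baseline"]
).cause_concentration_write_random_baseline is True
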
@@ -1645,6 +1779,7 @@ def main() -> int:
     cap_points_rows: List[Dict[str, Any]] = []
     cap_curve_rows: List[Dict[str, Any]] = []
     conc_rows: List[Dict[str, Any]] = []
+    conc_base_rows: List[Dict[str, Any]] = []
 
     # Track per-model integrity status for meta JSON.
     integrity_meta: Dict[str, Any] = {}
@@ -1774,6 +1909,21 @@ def main() -> int:
                     }
                 )
+
+            if bool(args.cause_concentration_write_random_baseline):
+                for rr in compute_random_ranking_baseline_topk(y_tau_all, topk_causes):
+                    conc_base_rows.append(
+                        {
+                            "model_id": model_id,
+                            "model_type": model_type,
+                            "loss_type": loss_type_id,
+                            "age_encoder": age_encoder,
+                            "cov_type": cov_type,
+                            "horizon": float(tau),
+                            "sex": sex_label,
+                            **rr,
+                        }
+                    )
 
         # Convenience slices for user-facing experiments (focus causes only).
         cause_cif_focus = cif_full[:, top_cause_ids, :]
         y_within_focus = y_cause_within_tau[:, top_cause_ids, :]
@@ -1965,11 +2115,41 @@ def main() -> int:
                     "topk",
                     "event_rate_mean",
                     "event_rate_median",
+                    "recall_mean",
+                    "recall_median",
                     "n_total",
+                    "n_valid_recall",
                 ],
                 conc_rows,
             )
+
+        if conc_base_rows:
+            write_simple_csv(
+                os.path.join(
+                    export_dir, "high_risk_cause_concentration_random_baseline.csv"),
+                [
+                    "model_id",
+                    "model_type",
+                    "loss_type",
+                    "age_encoder",
+                    "cov_type",
+                    "horizon",
+                    "sex",
+                    "topk",
+                    "baseline_event_rate_mean",
+                    "baseline_event_rate_p05",
+                    "baseline_event_rate_p95",
+                    "baseline_recall_mean",
+                    "baseline_recall_p05",
+                    "baseline_recall_p95",
+                    "n_total",
+                    "n_valid_recall",
+                    "k_total",
+                    "baseline_method",
+                ],
+                conc_base_rows,
+            )
 
         # Manifest markdown (stable, user-facing)
         manifest_path = os.path.join(export_dir, "eval_exports_manifest.md")
         with open(manifest_path, "w", encoding="utf-8") as f:
@@ -1981,7 +2161,8 @@ def main() -> int:
             "- focus_causes.csv: The deterministically selected focus causes (Death + focus_k). Intended plot: bar of event support + label table.\n"
             "- lift_capture_points.csv: Capture/precision at top {1,5,10,20}% risk. Intended plot: bar/line showing event capture vs resources.\n"
             "- lift_capture_curve.csv (optional): Dense capture curve for k=1..N%. Intended plot: gain curve overlay across models.\n"
-            "- high_risk_cause_concentration.csv: Event Rate@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart of Event Rate@K vs K.\n"
+            "- high_risk_cause_concentration.csv: Event Rate@K and Recall@K when ranking ALL causes per individual by predicted CIF at each horizon (K from --cause_concentration_topk). Intended plot: line chart vs K.\n"
+            "- high_risk_cause_concentration_random_baseline.csv (optional): Random-ranking baseline for Event Rate@K and Recall@K with an uncertainty range (enabled by --cause_concentration_write_random_baseline).\n"
         )
 
         meta = {
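
Editor's note: the manifest's "line chart vs K" amounts to overlaying a model's curve on the baseline band from the two CSVs. A hypothetical plotting sketch: the baseline columns are exactly those written above, but it assumes high_risk_cause_concentration.csv carries the same model_id column, and "some_model" is a placeholder:

import matplotlib.pyplot as plt
import pandas as pd

conc = pd.read_csv("high_risk_cause_concentration.csv")
base = pd.read_csv("high_risk_cause_concentration_random_baseline.csv")

# Placeholder id; in practice also filter by horizon and sex before plotting.
m = conc[conc["model_id"] == "some_model"].sort_values("topk")
b = base[base["model_id"] == "some_model"].sort_values("topk")

plt.plot(m["topk"], m["event_rate_mean"], marker="o", label="model")
plt.plot(b["topk"], b["baseline_event_rate_mean"], ls="--", label="random mean")
plt.fill_between(b["topk"], b["baseline_event_rate_p05"],
                 b["baseline_event_rate_p95"], alpha=0.3, label="random 5-95%")
plt.xlabel("K (top causes per individual)")
plt.ylabel("Event Rate@K")
plt.legend()
plt.show()
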