From 7840a4c53e10281e1f01c0a909b6ccec7a8da07a Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sat, 17 Jan 2026 14:09:50 +0800 Subject: [PATCH] Refactor next-event evaluation logic and add age-bin metrics output --- evaluate_next_event.py | 240 +++++++++++++++++++++++++++++------------ 1 file changed, 169 insertions(+), 71 deletions(-) diff --git a/evaluate_next_event.py b/evaluate_next_event.py index 6c811e5..130ee69 100644 --- a/evaluate_next_event.py +++ b/evaluate_next_event.py @@ -27,6 +27,7 @@ from utils import ( roc_auc_ovr, seed_everything, topk_indices, + DAYS_PER_YEAR, ) @@ -66,6 +67,127 @@ def parse_args() -> argparse.Namespace: return p.parse_args() +def _format_age_bin_label(lo: float, hi: float) -> str: + if np.isinf(hi): + return f"[{lo}, inf)" + return f"[{lo}, {hi})" + + +def _compute_next_event_metrics( + *, + scores: np.ndarray, + y_next: np.ndarray, + tau_short: float, + min_pos: int, +) -> tuple[list[dict], pd.DataFrame]: + """Compute next-event metrics on a given subset. + + Definitions are unchanged from the original script. + """ + n_records_total = int(y_next.size) + eligible = y_next >= 0 + n_eligible = int(eligible.sum()) + coverage = float( + n_eligible / n_records_total) if n_records_total > 0 else 0.0 + + metrics_rows: List[dict] = [] + metrics_rows.append({"metric": "n_records_total", "value": n_records_total}) + metrics_rows.append( + {"metric": "n_next_event_eligible", "value": n_eligible}) + metrics_rows.append({"metric": "coverage", "value": coverage}) + metrics_rows.append( + {"metric": "tau_short_years", "value": float(tau_short)}) + + K = int(scores.shape[1]) + if n_records_total == 0: + per_cause_df = pd.DataFrame( + { + "cause_id": np.arange(K, dtype=np.int64), + "n_pos": np.zeros((K,), dtype=np.int64), + "n_neg": np.zeros((K,), dtype=np.int64), + "auc": np.full((K,), np.nan, dtype=np.float64), + "included": np.zeros((K,), dtype=bool), + } + ) + metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")}) + metrics_rows.append({"metric": "mrr", "value": float("nan")}) + for k in [1, 3, 5, 10, 20]: + metrics_rows.append( + {"metric": f"hitrate_at_{k}", "value": float("nan")}) + metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")}) + return metrics_rows, per_cause_df + + # If no eligible, keep coverage but leave accuracy-like metrics as NaN. 
+    if n_eligible == 0:
+        per_cause_df = pd.DataFrame(
+            {
+                "cause_id": np.arange(K, dtype=np.int64),
+                "n_pos": np.zeros((K,), dtype=np.int64),
+                "n_neg": np.zeros((K,), dtype=np.int64),  # 0 eligible rows
+                "auc": np.full((K,), np.nan, dtype=np.float64),
+                "included": np.zeros((K,), dtype=bool),
+            }
+        )
+        metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
+        metrics_rows.append({"metric": "mrr", "value": float("nan")})
+        for k in [1, 3, 5, 10, 20]:
+            metrics_rows.append(
+                {"metric": f"hitrate_at_{k}", "value": float("nan")})
+        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
+        return metrics_rows, per_cause_df
+
+    scores_e = scores[eligible]
+    y_e = y_next[eligible]
+
+    pred = scores_e.argmax(axis=1)
+    acc = float((pred == y_e).mean())
+    metrics_rows.append({"metric": "top1_accuracy", "value": acc})
+
+    # MRR
+    order = np.argsort(-scores_e, axis=1, kind="mergesort")
+    ranks = np.empty(y_e.shape[0], dtype=np.int32)
+    for i in range(y_e.shape[0]):
+        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
+    mrr = float((1.0 / ranks).mean())
+    metrics_rows.append({"metric": "mrr", "value": mrr})
+
+    # HitRate@K
+    for k in [1, 3, 5, 10, 20]:
+        topk = topk_indices(scores_e, k)
+        hit = (topk == y_e[:, None]).any(axis=1)
+        metrics_rows.append({"metric": f"hitrate_at_{k}",
+                             "value": float(hit.mean())})
+
+    # Macro OvR AUC per cause (optional)
+    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
+    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
+
+    auc = np.full((K,), np.nan, dtype=np.float64)
+    candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
+    for k in candidates:
+        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
+
+    included = (n_pos >= int(min_pos)) & (n_neg > 0)
+    per_cause_df = pd.DataFrame(
+        {
+            "cause_id": np.arange(K, dtype=np.int64),
+            "n_pos": n_pos,
+            "n_neg": n_neg,
+            "auc": auc,
+            "included": included,
+        }
+    )
+
+    aucs = auc[np.isfinite(auc)]
+    if aucs.size > 0:
+        metrics_rows.append(
+            {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
+    else:
+        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
+
+    return metrics_rows, per_cause_df
+
+
 def main() -> None:
     args = parse_args()
     seed_everything(args.seed)
@@ -118,92 +240,68 @@ def main() -> None:
     if scores.ndim == 3:
         scores = scores[:, :, 0]

-    n_records_total = len(records)
     y_next = np.array(
         [(-1 if r.next_event_cause is None else int(r.next_event_cause))
          for r in records],
         dtype=np.int64,
     )
-    eligible = y_next >= 0
-    n_eligible = int(eligible.sum())
-    coverage = float(
-        n_eligible / n_records_total) if n_records_total > 0 else 0.0
-    metrics_rows: List[dict] = []
-    metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
-    metrics_rows.append(
-        {"metric": "n_next_event_eligible", "value": n_eligible})
-    metrics_rows.append({"metric": "coverage", "value": coverage})
-    metrics_rows.append({"metric": "tau_short_years", "value": tau})
-
-    if n_eligible == 0:
-        out_path = os.path.join(run_dir, "next_event_metrics.csv")
-        pd.DataFrame(metrics_rows).to_csv(out_path, index=False)
-        print(f"No eligible records; wrote {out_path}")
-        return
-
-    scores_e = scores[eligible]
-    y_e = y_next[eligible]
-
-    pred = scores_e.argmax(axis=1)
-    acc = float((pred == y_e).mean())
-    metrics_rows.append({"metric": "top1_accuracy", "value": acc})
-
-    # MRR
-    order = np.argsort(-scores_e, axis=1, kind="mergesort")
-    ranks = np.empty(y_e.shape[0], dtype=np.int32)
-    for i in range(y_e.shape[0]):
-        ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
-    mrr = float((1.0 / ranks).mean())
-    metrics_rows.append({"metric": "mrr", "value": mrr})
-
-    # HitRate@K
-    for k in [1, 3, 5, 10, 20]:
-        topk = topk_indices(scores_e, k)
-        hit = (topk == y_e[:, None]).any(axis=1)
-        metrics_rows.append({"metric": f"hitrate_at_{k}",
-                             "value": float(hit.mean())})
-
-    # Macro OvR AUC per cause (optional)
-    K = scores.shape[1]
-    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
-    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
-
-    auc = np.full((K,), np.nan, dtype=np.float64)
-    min_pos = int(args.min_pos)
-    candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
-    for k in candidates:
-        auc_k = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
-        auc[k] = auc_k
-
-    included = (n_pos >= min_pos) & (n_neg > 0)
-    per_cause_df = pd.DataFrame(
-        {
-            "cause_id": np.arange(K, dtype=np.int64),
-            "n_pos": n_pos,
-            "n_neg": n_neg,
-            "auc": auc,
-            "included": included,
-        }
+    # Overall (preserve existing output files/shape)
+    metrics_rows, per_cause_df = _compute_next_event_metrics(
+        scores=scores,
+        y_next=y_next,
+        tau_short=tau,
+        min_pos=int(args.min_pos),
     )
-    aucs = auc[np.isfinite(auc)]
-
-    if aucs.size > 0:
-        metrics_rows.append(
-            {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
-    else:
-        metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
-
     out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
     pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
-
-    # optional per-cause
     out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
     per_cause_df.to_csv(out_pc, index=False)

+    # By age bin (new outputs)
+    age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
+    age_bins_days = age_bins_years * DAYS_PER_YEAR
+    # Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
+    t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
+    bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1
+
+    per_bin_metric_rows: List[dict] = []
+    per_bin_cause_parts: List[pd.DataFrame] = []
+    for b in range(len(age_bins_years) - 1):
+        lo = float(age_bins_years[b])
+        hi = float(age_bins_years[b + 1])
+        label = _format_age_bin_label(lo, hi)
+        m = bin_idx == b
+        m_scores = scores[m]
+        m_y = y_next[m]
+        m_rows, m_pc = _compute_next_event_metrics(
+            scores=m_scores,
+            y_next=m_y,
+            tau_short=tau,
+            min_pos=int(args.min_pos),
+        )
+        for row in m_rows:
+            per_bin_metric_rows.append({"age_bin": label, **row})
+        m_pc = m_pc.copy()
+        m_pc.insert(0, "age_bin", label)
+        per_bin_cause_parts.append(m_pc)
+
+    out_metrics_bins = os.path.join(
+        run_dir, "next_event_metrics_by_age_bin.csv")
+    pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
+    out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
+    if per_bin_cause_parts:
+        pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
+            out_pc_bins, index=False)
+    else:
+        pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
+                              "auc", "included"]).to_csv(out_pc_bins, index=False)
+
     print(f"Wrote {out_metrics}")
     print(f"Wrote {out_pc}")
+    print(f"Wrote {out_metrics_bins}")
+    print(f"Wrote {out_pc_bins}")


 if __name__ == "__main__":
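
Reviewer notes (illustrative sketches, not part of the patch).

The age-bin assignment uses the standard numpy idiom for half-open bins. A
minimal standalone sketch of the boundary behavior, with made-up bin edges and
ages (not taken from this repo):

    import numpy as np

    # Hypothetical edges in years; the final inf edge makes an open-ended bin.
    edges_years = np.array([0.0, 40.0, 60.0, 80.0, np.inf])
    ages = np.array([39.9, 40.0, 59.999, 60.0, 95.0])

    # side="right" maps a value exactly on edge b_i into bin i, i.e. [b_i, b_{i+1}).
    idx = np.searchsorted(edges_years, ages, side="right") - 1
    print(idx)  # [0 1 1 2 3]: 40.0 lands in [40, 60), 60.0 in [60, 80)

With side="right", a t0 that falls exactly on an edge b_i is assigned to
[b_i, b_{i+1}), matching both the code comment and _format_age_bin_label;
side="left" would instead implement (b_i, b_{i+1}].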
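
The rank/MRR definition in _compute_next_event_metrics can be checked on toy
data. The scores and labels below are invented, and `ranks` here is a
vectorized equivalent of the per-row np.where loop in the patch:

    import numpy as np

    scores_e = np.array([[0.1, 0.7, 0.2],
                         [0.5, 0.3, 0.2]])  # hypothetical per-cause scores
    y_e = np.array([2, 0])                  # true next-event cause ids

    # Causes sorted by descending score; the rank of the true cause is its
    # 1-based position in that ordering.
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = (order == y_e[:, None]).argmax(axis=1) + 1
    print(ranks)                        # [2 1]
    print(float((1.0 / ranks).mean()))  # MRR = (1/2 + 1/1) / 2 = 0.75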
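
Finally, the per-cause inclusion rule for the macro OvR AUC: a cause is scored
only when it has at least min_pos positives and at least one negative among
eligible rows. The counts below are made up:

    import numpy as np

    min_pos = 5                          # mirrors --min_pos
    n_pos = np.array([0, 3, 12, 7])      # hypothetical per-cause positives
    n_neg = np.array([100, 97, 88, 0])   # hypothetical per-cause negatives

    included = (n_pos >= min_pos) & (n_neg > 0)
    print(included)  # [False False  True False]: only cause 2 enters the macro mean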