Enhance next-event evaluation with age-bin metrics and diagnostic AUC outputs

This commit is contained in:
2026-01-17 15:31:12 +08:00
parent 197842b1a6
commit fcd948818c
2 changed files with 348 additions and 183 deletions

View File

@@ -41,8 +41,8 @@ def parse_args() -> argparse.Namespace:
"--age_bins",
type=str,
nargs="+",
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
default=["40", "45", "50", "55", "60", "65", "70", "inf"],
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
)
p.add_argument(
@@ -80,9 +80,14 @@ def _compute_next_event_metrics(
tau_short: float,
min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
"""Compute next-event metrics on a given subset.
"""Compute next-event *primary* metrics on a given subset.
Definitions are unchanged from the original script.
Implements 评估方案.md (Next-event):
- score_k = CIF_k(tau_short)
- Hit@K / MRR are computed on records with an observed next-event.
Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table
based on whether the cause occurs within (t0, t0+tau_short] (display-only).
"""
n_records_total = int(y_next.size)
eligible = y_next >= 0
@@ -99,49 +104,40 @@ def _compute_next_event_metrics(
{"metric": "tau_short_years", "value": float(tau_short)})
K = int(scores.shape[1])
# Diagnostic: build per-cause AUC using within-window labels.
# This is NOT a primary metric (no IPCW / censoring adjustment).
diag_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.zeros((K,), dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
if n_records_total == 0:
per_cause_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.zeros((K,), dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]:
metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")})
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
return metrics_rows, per_cause_df
return metrics_rows, diag_df
# If no eligible, keep coverage but leave accuracy-like metrics as NaN.
if n_eligible == 0:
per_cause_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.full((K,), n_records_total, dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]:
metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")})
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
return metrics_rows, per_cause_df
return metrics_rows, diag_df
scores_e = scores[eligible]
y_e = y_next[eligible]
pred = scores_e.argmax(axis=1)
acc = float((pred == y_e).mean())
metrics_rows.append({"metric": "top1_accuracy", "value": acc})
metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
# MRR
order = np.argsort(-scores_e, axis=1, kind="mergesort")
@@ -158,17 +154,53 @@ def _compute_next_event_metrics(
metrics_rows.append({"metric": f"hitrate_at_{k}",
"value": float(hit.mean())})
# Macro OvR AUC per cause (optional)
n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
n_neg = (int(y_e.size) - n_pos).astype(np.int64)
# Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here.
_ = min_pos
return metrics_rows, diag_df
def _compute_within_window_auc(
*,
scores: np.ndarray,
records: list,
tau_short: float,
min_pos: int,
) -> pd.DataFrame:
"""Diagnostic-only per-cause AUC.
Label definition (event-driven, approximate; no IPCW):
y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
"""
n_records = int(len(records))
if n_records == 0:
return pd.DataFrame(
columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
)
K = int(scores.shape[1])
y = np.zeros((n_records, K), dtype=np.int8)
tau = float(tau_short)
# Build labels from future events.
for i, r in enumerate(records):
if r.future_causes.size == 0:
continue
m = r.future_dt_years <= tau
if not np.any(m):
continue
y[i, r.future_causes[m]] = 1
n_pos = y.sum(axis=0).astype(np.int64)
n_neg = (int(n_records) - n_pos).astype(np.int64)
auc = np.full((K,), np.nan, dtype=np.float64)
candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
for k in candidates:
auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
scores[:, k].astype(np.float64))
included = (n_pos >= int(min_pos)) & (n_neg > 0)
per_cause_df = pd.DataFrame(
return pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
@@ -178,15 +210,6 @@ def _compute_next_event_metrics(
}
)
aucs = auc[np.isfinite(auc)]
if aucs.size > 0:
metrics_rows.append(
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
else:
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
return metrics_rows, per_cause_df
def main() -> None:
args = parse_args()
@@ -246,19 +269,7 @@ def main() -> None:
)
# Overall (preserve existing output files/shape)
metrics_rows, per_cause_df = _compute_next_event_metrics(
scores=scores,
y_next=y_next,
tau_short=tau,
min_pos=int(args.min_pos),
)
out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
per_cause_df.to_csv(out_pc, index=False)
# By age bin (new outputs)
# Strict protocol: evaluate independently per age bin (no mixing).
age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
age_bins_days = age_bins_years * DAYS_PER_YEAR
# Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
@@ -266,7 +277,7 @@ def main() -> None:
bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1
per_bin_metric_rows: List[dict] = []
per_bin_cause_parts: List[pd.DataFrame] = []
per_bin_auc_parts: List[pd.DataFrame] = []
for b in range(len(age_bins_years) - 1):
lo = float(age_bins_years[b])
hi = float(age_bins_years[b + 1])
@@ -274,6 +285,7 @@ def main() -> None:
m = bin_idx == b
m_scores = scores[m]
m_y = y_next[m]
m_records = [r for r, keep in zip(records, m) if bool(keep)]
m_rows, m_pc = _compute_next_event_metrics(
scores=m_scores,
y_next=m_y,
@@ -282,25 +294,32 @@ def main() -> None:
)
for row in m_rows:
per_bin_metric_rows.append({"age_bin": label, **row})
m_pc = m_pc.copy()
m_pc.insert(0, "age_bin", label)
per_bin_cause_parts.append(m_pc)
m_auc = _compute_within_window_auc(
scores=m_scores,
records=m_records,
tau_short=tau,
min_pos=int(args.min_pos),
)
m_auc.insert(0, "age_bin", label)
m_auc.insert(1, "tau_short_years", float(tau))
per_bin_auc_parts.append(m_auc)
out_metrics_bins = os.path.join(
run_dir, "next_event_metrics_by_age_bin.csv")
pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
if per_bin_cause_parts:
pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
out_pc_bins, index=False)
else:
pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
"auc", "included"]).to_csv(out_pc_bins, index=False)
print(f"Wrote {out_metrics}")
print(f"Wrote {out_pc}")
out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
if per_bin_auc_parts:
pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
out_auc_bins, index=False)
else:
pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
"n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
print(f"Wrote {out_metrics_bins}")
print(f"Wrote {out_pc_bins}")
print(f"Wrote {out_auc_bins}")
if __name__ == "__main__":