Refactor next-event evaluation to use next-token scores and implement clean control AUC metrics

2026-01-17 23:34:19 +08:00
parent e56068e668
commit 4686b56336
2 changed files with 286 additions and 167 deletions


@@ -23,20 +23,18 @@ from utils import (
load_checkpoint_into,
load_train_config,
parse_float_list,
predict_cifs,
roc_auc_ovr,
predict_next_token_logits,
get_auc_delong_var,
seed_everything,
topk_indices,
DAYS_PER_YEAR,
)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Evaluate next-event prediction using short-window CIF"
description="Evaluate next-event prediction using next-token scores"
)
p.add_argument("--run_dir", type=str, required=True)
p.add_argument("--tau_short", type=float, required=True, help="years")
p.add_argument(
"--age_bins",
type=str,
@@ -73,140 +71,69 @@ def _format_age_bin_label(lo: float, hi: float) -> str:
return f"[{lo}, {hi})"
def _compute_next_event_metrics(
*,
scores: np.ndarray,
y_next: np.ndarray,
tau_short: float,
min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
"""Compute next-event *primary* metrics on a given subset.
Implements 评估方案.md (the evaluation-plan doc), "Next-event" section:
- score_k = CIF_k(tau_short)
- Hit@K / MRR are computed on records with an observed next event.
Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table
based on whether the cause occurs within (t0, t0+tau_short] (display-only).
"""
n_records_total = int(y_next.size)
eligible = y_next >= 0
n_eligible = int(eligible.sum())
coverage = float(
n_eligible / n_records_total) if n_records_total > 0 else 0.0
metrics_rows: List[dict] = []
metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
metrics_rows.append(
{"metric": "n_next_event_eligible", "value": n_eligible})
metrics_rows.append({"metric": "coverage", "value": coverage})
metrics_rows.append(
{"metric": "tau_short_years", "value": float(tau_short)})
K = int(scores.shape[1])
# Diagnostic: build per-cause AUC using within-window labels.
# This is NOT a primary metric (no IPCW / censoring adjustment).
diag_df = pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": np.zeros((K,), dtype=np.int64),
"n_neg": np.zeros((K,), dtype=np.int64),
"auc": np.full((K,), np.nan, dtype=np.float64),
"included": np.zeros((K,), dtype=bool),
}
)
if n_records_total == 0:
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]:
metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")})
return metrics_rows, diag_df
# If there are no eligible records, keep coverage but leave accuracy-like metrics as NaN.
if n_eligible == 0:
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
metrics_rows.append({"metric": "mrr", "value": float("nan")})
for k in [1, 3, 5, 10, 20]:
metrics_rows.append(
{"metric": f"hitrate_at_{k}", "value": float("nan")})
return metrics_rows, diag_df
scores_e = scores[eligible]
y_e = y_next[eligible]
pred = scores_e.argmax(axis=1)
acc = float((pred == y_e).mean())
metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
# MRR
order = np.argsort(-scores_e, axis=1, kind="mergesort")
ranks = np.empty(y_e.shape[0], dtype=np.int32)
for i in range(y_e.shape[0]):
ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
mrr = float((1.0 / ranks).mean())
metrics_rows.append({"metric": "mrr", "value": mrr})
# HitRate@K
for k in [1, 3, 5, 10, 20]:
topk = topk_indices(scores_e, k)
hit = (topk == y_e[:, None]).any(axis=1)
metrics_rows.append({"metric": f"hitrate_at_{k}",
"value": float(hit.mean())})
# The diagnostic per-cause AUC is computed outside this function (it needs future events), so only the placeholder table is returned here.
_ = min_pos
return metrics_rows, diag_df
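# --- Editor's note (not part of this commit): a minimal, self-contained check of the
# Hit@K / MRR logic above, assuming scores has shape (N, K) and y_next uses -1 for
# records without an observed next event. ---
import numpy as np

toy_scores = np.array([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2],
                       [0.2, 0.2, 0.6]])
toy_y_next = np.array([1, 2, -1])                    # third record: no next event
elig = toy_y_next >= 0
s, y = toy_scores[elig], toy_y_next[elig]
order = np.argsort(-s, axis=1)                       # causes ranked best-first
ranks = (order == y[:, None]).argmax(axis=1) + 1     # 1-based rank of the true cause
print(float((ranks <= 1).mean()))                    # Hit@1 = 0.5
print(float((1.0 / ranks).mean()))                   # MRR = (1/1 + 1/3) / 2 ≈ 0.667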
def _compute_within_window_auc(
def _compute_next_event_auc_clean_control(
*,
scores: np.ndarray,
records: list,
tau_short: float,
min_pos: int,
) -> pd.DataFrame:
"""Diagnostic-only per-cause AUC.
"""Delphi-2M next-event AUC (clean control) per cause.
Label definition (event-driven, approximate; no IPCW):
y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
Definitions per cause k:
- Case: next_event_cause == k
- Control (clean): next_event_cause != k AND k not in record.lifetime_causes
AUC is reported together with its DeLong variance.
"""
n_records = int(len(records))
if n_records == 0:
return pd.DataFrame(
columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
columns=["cause_id", "n_case", "n_control", "auc", "auc_variance"],
)
K = int(scores.shape[1])
y = np.zeros((n_records, K), dtype=np.int8)
tau = float(tau_short)
# Build labels from future events.
for i, r in enumerate(records):
if r.future_causes.size == 0:
continue
m = r.future_dt_years <= tau
if not np.any(m):
continue
y[i, r.future_causes[m]] = 1
n_pos = y.sum(axis=0).astype(np.int64)
n_neg = (int(n_records) - n_pos).astype(np.int64)
y_next = np.array(
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
for r in records],
dtype=np.int64,
)
auc = np.full((K,), np.nan, dtype=np.float64)
candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
for k in candidates:
auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
scores[:, k].astype(np.float64))
var = np.full((K,), np.nan, dtype=np.float64)
n_case = np.zeros((K,), dtype=np.int64)
n_control = np.zeros((K,), dtype=np.int64)
for k in range(K):
case_mask = y_next == k
if not np.any(case_mask):
continue
# Clean controls: not next-event k AND never had k in their lifetime history.
control_mask = y_next != k
if np.any(control_mask):
clean = np.fromiter(
((k not in rec.lifetime_causes) for rec in records),
dtype=bool,
count=n_records,
)
control_mask = control_mask & clean
cs = scores[case_mask, k]
hs = scores[control_mask, k]
n_case[k] = int(cs.size)
n_control[k] = int(hs.size)
if cs.size == 0 or hs.size == 0:
continue
a, v = get_auc_delong_var(hs, cs)
auc[k] = float(a)
var[k] = float(v)
included = (n_pos >= int(min_pos)) & (n_neg > 0)
return pd.DataFrame(
{
"cause_id": np.arange(K, dtype=np.int64),
"n_pos": n_pos,
"n_neg": n_neg,
"n_case": n_case,
"n_control": n_control,
"auc": auc,
"included": included,
"auc_variance": var,
}
)
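# --- Editor's note (not part of this commit): a minimal single-curve DeLong reference
# implementation, to clarify what get_auc_delong_var is assumed to return above
# (AUC plus its DeLong variance). The helper name and signature below are hypothetical;
# note that the call above passes control scores first (get_auc_delong_var(hs, cs)),
# so the project's argument order may differ. ---
import numpy as np

def delong_auc_and_variance(case_scores: np.ndarray, control_scores: np.ndarray) -> tuple[float, float]:
    """Return (auc, variance) from DeLong's structural components."""
    cases = np.asarray(case_scores, dtype=np.float64)
    controls = np.asarray(control_scores, dtype=np.float64)
    m, n = cases.size, controls.size
    # Mann-Whitney kernel: 1 if case > control, 0.5 on ties, 0 otherwise.
    psi = (cases[:, None] > controls[None, :]).astype(np.float64)
    psi += 0.5 * (cases[:, None] == controls[None, :])
    auc = psi.mean()
    v10 = psi.mean(axis=1)                 # structural component per case
    v01 = psi.mean(axis=0)                 # structural component per control
    var = v10.var(ddof=1) / m + v01.var(ddof=1) / n
    return float(auc), float(var)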
@@ -247,20 +174,15 @@ def main() -> None:
**dl_kwargs,
)
tau = float(args.tau_short)
scores = predict_cifs(
scores = predict_next_token_logits(
model,
head,
criterion,
loader,
[tau],
device=device,
show_progress=show_progress,
progress_desc="Inference (next-event)",
progress_desc="Inference (next-token)",
return_probs=True,
)
# scores shape: (N, K, 1) when multiple taus are requested; squeeze the last axis.
if scores.ndim == 3:
scores = scores[:, :, 0]
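# --- Editor's note (not part of this commit): the assumed meaning of return_probs=True
# above is that next-token logits of shape (N, K) are turned into per-record
# probabilities with a row-wise softmax; this is an assumption about
# predict_next_token_logits, not a statement of its implementation. ---
import numpy as np

def softmax_rows(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max(axis=1, keepdims=True)   # subtract row max for stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)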
y_next = np.array(
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
@@ -284,24 +206,24 @@ def main() -> None:
label = _format_age_bin_label(lo, hi)
m = bin_idx == b
m_scores = scores[m]
m_y = y_next[m]
m_records = [r for r, keep in zip(records, m) if bool(keep)]
m_rows, m_pc = _compute_next_event_metrics(
scores=m_scores,
y_next=m_y,
tau_short=tau,
min_pos=int(args.min_pos),
)
for row in m_rows:
per_bin_metric_rows.append({"age_bin": label, **row})
m_auc = _compute_within_window_auc(
# Coverage metrics for transparency (not part of the Delphi-2M AUC itself).
m_y = y_next[m]
n_total = int(m_y.size)
n_eligible = int((m_y >= 0).sum())
coverage = float(n_eligible / n_total) if n_total > 0 else 0.0
per_bin_metric_rows.append(
{"age_bin": label, "metric": "n_records_total", "value": n_total})
per_bin_metric_rows.append(
{"age_bin": label, "metric": "n_next_event_eligible", "value": n_eligible})
per_bin_metric_rows.append(
{"age_bin": label, "metric": "coverage", "value": coverage})
m_auc = _compute_next_event_auc_clean_control(
scores=m_scores,
records=m_records,
tau_short=tau,
min_pos=int(args.min_pos),
)
m_auc.insert(0, "age_bin", label)
m_auc.insert(1, "tau_short_years", float(tau))
per_bin_auc_parts.append(m_auc)
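# --- Editor's note (not part of this commit): one plausible way the bin_idx used by
# the loop above could be built, assuming --age_bins supplies bin edges in years and
# each record exposes its age at t0; the attribute name age_at_t0_years is hypothetical. ---
import numpy as np

edges = np.asarray([40.0, 50.0, 60.0, 70.0, 80.0])       # e.g. parsed from --age_bins
ages = np.asarray([r.age_at_t0_years for r in records])  # hypothetical attribute
bin_idx = np.digitize(ages, edges) - 1                    # -1 => age below first edge
bins = list(zip(edges[:-1], edges[1:]))                   # (lo, hi) pairs for [lo, hi)
# Indices outside 0..len(bins)-1 fall outside the requested age bins.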
out_metrics_bins = os.path.join(
@@ -313,11 +235,11 @@ def main() -> None:
pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
out_auc_bins, index=False)
else:
pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
"n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
pd.DataFrame(columns=["age_bin", "cause_id", "n_case", "n_control",
"auc", "auc_variance"]).to_csv(out_auc_bins, index=False)
print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
print("PRIMARY METRICS: Per-cause AUC is reported per age bin using Delphi-2M clean controls.")
print("EVAL METHOD: DeLong AUC variance is reported (per cause).")
print(f"Wrote {out_metrics_bins}")
print(f"Wrote {out_auc_bins}")