Enhance next-event evaluation with age-bin metrics and diagnostic AUC outputs
This commit is contained in:
@@ -41,8 +41,8 @@ def parse_args() -> argparse.Namespace:
|
||||
"--age_bins",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["40", "45", "50", "55", "60", "65", "70", "75", "inf"],
|
||||
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 75 inf)",
|
||||
default=["40", "45", "50", "55", "60", "65", "70", "inf"],
|
||||
help="Age bin boundaries in years (default: 40 45 50 55 60 65 70 inf)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
@@ -80,9 +80,14 @@ def _compute_next_event_metrics(
|
||||
tau_short: float,
|
||||
min_pos: int,
|
||||
) -> tuple[list[dict], pd.DataFrame]:
|
||||
"""Compute next-event metrics on a given subset.
|
||||
"""Compute next-event *primary* metrics on a given subset.
|
||||
|
||||
Definitions are unchanged from the original script.
|
||||
Implements 评估方案.md (Next-event):
|
||||
- score_k = CIF_k(tau_short)
|
||||
- Hit@K / MRR are computed on records with an observed next-event.
|
||||
|
||||
Returns (metrics_rows, diag_df). diag_df is a diagnostic per-cause AUC table
|
||||
based on whether the cause occurs within (t0, t0+tau_short] (display-only).
|
||||
"""
|
||||
n_records_total = int(y_next.size)
|
||||
eligible = y_next >= 0
|
||||
@@ -99,49 +104,40 @@ def _compute_next_event_metrics(
|
||||
{"metric": "tau_short_years", "value": float(tau_short)})
|
||||
|
||||
K = int(scores.shape[1])
|
||||
# Diagnostic: build per-cause AUC using within-window labels.
|
||||
# This is NOT a primary metric (no IPCW / censoring adjustment).
|
||||
diag_df = pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": np.zeros((K,), dtype=np.int64),
|
||||
"n_neg": np.zeros((K,), dtype=np.int64),
|
||||
"auc": np.full((K,), np.nan, dtype=np.float64),
|
||||
"included": np.zeros((K,), dtype=bool),
|
||||
}
|
||||
)
|
||||
if n_records_total == 0:
|
||||
per_cause_df = pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": np.zeros((K,), dtype=np.int64),
|
||||
"n_neg": np.zeros((K,), dtype=np.int64),
|
||||
"auc": np.full((K,), np.nan, dtype=np.float64),
|
||||
"included": np.zeros((K,), dtype=bool),
|
||||
}
|
||||
)
|
||||
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "mrr", "value": float("nan")})
|
||||
for k in [1, 3, 5, 10, 20]:
|
||||
metrics_rows.append(
|
||||
{"metric": f"hitrate_at_{k}", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
||||
return metrics_rows, per_cause_df
|
||||
return metrics_rows, diag_df
|
||||
|
||||
# If no eligible, keep coverage but leave accuracy-like metrics as NaN.
|
||||
if n_eligible == 0:
|
||||
per_cause_df = pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": np.zeros((K,), dtype=np.int64),
|
||||
"n_neg": np.full((K,), n_records_total, dtype=np.int64),
|
||||
"auc": np.full((K,), np.nan, dtype=np.float64),
|
||||
"included": np.zeros((K,), dtype=bool),
|
||||
}
|
||||
)
|
||||
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "hitrate_at_1", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "mrr", "value": float("nan")})
|
||||
for k in [1, 3, 5, 10, 20]:
|
||||
metrics_rows.append(
|
||||
{"metric": f"hitrate_at_{k}", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
||||
return metrics_rows, per_cause_df
|
||||
return metrics_rows, diag_df
|
||||
|
||||
scores_e = scores[eligible]
|
||||
y_e = y_next[eligible]
|
||||
|
||||
pred = scores_e.argmax(axis=1)
|
||||
acc = float((pred == y_e).mean())
|
||||
metrics_rows.append({"metric": "top1_accuracy", "value": acc})
|
||||
metrics_rows.append({"metric": "hitrate_at_1", "value": acc})
|
||||
|
||||
# MRR
|
||||
order = np.argsort(-scores_e, axis=1, kind="mergesort")
|
||||
@@ -158,17 +154,53 @@ def _compute_next_event_metrics(
|
||||
metrics_rows.append({"metric": f"hitrate_at_{k}",
|
||||
"value": float(hit.mean())})
|
||||
|
||||
# Macro OvR AUC per cause (optional)
|
||||
n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
|
||||
n_neg = (int(y_e.size) - n_pos).astype(np.int64)
|
||||
# Diagnostic per-cause AUC is computed outside (needs future events), so keep placeholder here.
|
||||
_ = min_pos
|
||||
return metrics_rows, diag_df
|
||||
|
||||
|
||||
def _compute_within_window_auc(
|
||||
*,
|
||||
scores: np.ndarray,
|
||||
records: list,
|
||||
tau_short: float,
|
||||
min_pos: int,
|
||||
) -> pd.DataFrame:
|
||||
"""Diagnostic-only per-cause AUC.
|
||||
|
||||
Label definition (event-driven, approximate; no IPCW):
|
||||
y[i,k]=1 iff at least one event of cause k occurs in (t0, t0+tau_short].
|
||||
"""
|
||||
n_records = int(len(records))
|
||||
if n_records == 0:
|
||||
return pd.DataFrame(
|
||||
columns=["cause_id", "n_pos", "n_neg", "auc", "included"],
|
||||
)
|
||||
|
||||
K = int(scores.shape[1])
|
||||
y = np.zeros((n_records, K), dtype=np.int8)
|
||||
tau = float(tau_short)
|
||||
|
||||
# Build labels from future events.
|
||||
for i, r in enumerate(records):
|
||||
if r.future_causes.size == 0:
|
||||
continue
|
||||
m = r.future_dt_years <= tau
|
||||
if not np.any(m):
|
||||
continue
|
||||
y[i, r.future_causes[m]] = 1
|
||||
|
||||
n_pos = y.sum(axis=0).astype(np.int64)
|
||||
n_neg = (int(n_records) - n_pos).astype(np.int64)
|
||||
|
||||
auc = np.full((K,), np.nan, dtype=np.float64)
|
||||
candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
|
||||
for k in candidates:
|
||||
auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
|
||||
auc[k] = roc_auc_ovr(y[:, k].astype(np.int32),
|
||||
scores[:, k].astype(np.float64))
|
||||
|
||||
included = (n_pos >= int(min_pos)) & (n_neg > 0)
|
||||
per_cause_df = pd.DataFrame(
|
||||
return pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": n_pos,
|
||||
@@ -178,15 +210,6 @@ def _compute_next_event_metrics(
|
||||
}
|
||||
)
|
||||
|
||||
aucs = auc[np.isfinite(auc)]
|
||||
if aucs.size > 0:
|
||||
metrics_rows.append(
|
||||
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
|
||||
else:
|
||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
||||
|
||||
return metrics_rows, per_cause_df
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
@@ -246,19 +269,7 @@ def main() -> None:
|
||||
)
|
||||
|
||||
# Overall (preserve existing output files/shape)
|
||||
metrics_rows, per_cause_df = _compute_next_event_metrics(
|
||||
scores=scores,
|
||||
y_next=y_next,
|
||||
tau_short=tau,
|
||||
min_pos=int(args.min_pos),
|
||||
)
|
||||
|
||||
out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
|
||||
pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
|
||||
out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
|
||||
per_cause_df.to_csv(out_pc, index=False)
|
||||
|
||||
# By age bin (new outputs)
|
||||
# Strict protocol: evaluate independently per age bin (no mixing).
|
||||
age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
|
||||
age_bins_days = age_bins_years * DAYS_PER_YEAR
|
||||
# Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
|
||||
@@ -266,7 +277,7 @@ def main() -> None:
|
||||
bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1
|
||||
|
||||
per_bin_metric_rows: List[dict] = []
|
||||
per_bin_cause_parts: List[pd.DataFrame] = []
|
||||
per_bin_auc_parts: List[pd.DataFrame] = []
|
||||
for b in range(len(age_bins_years) - 1):
|
||||
lo = float(age_bins_years[b])
|
||||
hi = float(age_bins_years[b + 1])
|
||||
@@ -274,6 +285,7 @@ def main() -> None:
|
||||
m = bin_idx == b
|
||||
m_scores = scores[m]
|
||||
m_y = y_next[m]
|
||||
m_records = [r for r, keep in zip(records, m) if bool(keep)]
|
||||
m_rows, m_pc = _compute_next_event_metrics(
|
||||
scores=m_scores,
|
||||
y_next=m_y,
|
||||
@@ -282,25 +294,32 @@ def main() -> None:
|
||||
)
|
||||
for row in m_rows:
|
||||
per_bin_metric_rows.append({"age_bin": label, **row})
|
||||
m_pc = m_pc.copy()
|
||||
m_pc.insert(0, "age_bin", label)
|
||||
per_bin_cause_parts.append(m_pc)
|
||||
m_auc = _compute_within_window_auc(
|
||||
scores=m_scores,
|
||||
records=m_records,
|
||||
tau_short=tau,
|
||||
min_pos=int(args.min_pos),
|
||||
)
|
||||
m_auc.insert(0, "age_bin", label)
|
||||
m_auc.insert(1, "tau_short_years", float(tau))
|
||||
per_bin_auc_parts.append(m_auc)
|
||||
|
||||
out_metrics_bins = os.path.join(
|
||||
run_dir, "next_event_metrics_by_age_bin.csv")
|
||||
pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
|
||||
out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
|
||||
if per_bin_cause_parts:
|
||||
pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
|
||||
out_pc_bins, index=False)
|
||||
else:
|
||||
pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
|
||||
"auc", "included"]).to_csv(out_pc_bins, index=False)
|
||||
|
||||
print(f"Wrote {out_metrics}")
|
||||
print(f"Wrote {out_pc}")
|
||||
out_auc_bins = os.path.join(run_dir, "next_event_auc_by_age_bin.csv")
|
||||
if per_bin_auc_parts:
|
||||
pd.concat(per_bin_auc_parts, ignore_index=True).to_csv(
|
||||
out_auc_bins, index=False)
|
||||
else:
|
||||
pd.DataFrame(columns=["age_bin", "tau_short_years", "cause_id", "n_pos",
|
||||
"n_neg", "auc", "included"]).to_csv(out_auc_bins, index=False)
|
||||
|
||||
print("PRIMARY METRICS: Hit@K / MRR are reported per age bin.")
|
||||
print("DIAGNOSTICS ONLY: AUC table is event-driven approximate (no IPCW).")
|
||||
print(f"Wrote {out_metrics_bins}")
|
||||
print(f"Wrote {out_pc_bins}")
|
||||
print(f"Wrote {out_auc_bins}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user