Refactor next-event evaluation logic and add age-bin metrics output
This commit is contained in:
@@ -27,6 +27,7 @@ from utils import (
|
||||
roc_auc_ovr,
|
||||
seed_everything,
|
||||
topk_indices,
|
||||
DAYS_PER_YEAR,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,6 +67,127 @@ def parse_args() -> argparse.Namespace:
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def _format_age_bin_label(lo: float, hi: float) -> str:
|
||||
if np.isinf(hi):
|
||||
return f"[{lo}, inf)"
|
||||
return f"[{lo}, {hi})"
|
||||
|
||||
|
||||
def _compute_next_event_metrics(
|
||||
*,
|
||||
scores: np.ndarray,
|
||||
y_next: np.ndarray,
|
||||
tau_short: float,
|
||||
min_pos: int,
|
||||
) -> tuple[list[dict], pd.DataFrame]:
|
||||
"""Compute next-event metrics on a given subset.
|
||||
|
||||
Definitions are unchanged from the original script.
|
||||
"""
|
||||
n_records_total = int(y_next.size)
|
||||
eligible = y_next >= 0
|
||||
n_eligible = int(eligible.sum())
|
||||
coverage = float(
|
||||
n_eligible / n_records_total) if n_records_total > 0 else 0.0
|
||||
|
||||
metrics_rows: List[dict] = []
|
||||
metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
|
||||
metrics_rows.append(
|
||||
{"metric": "n_next_event_eligible", "value": n_eligible})
|
||||
metrics_rows.append({"metric": "coverage", "value": coverage})
|
||||
metrics_rows.append(
|
||||
{"metric": "tau_short_years", "value": float(tau_short)})
|
||||
|
||||
K = int(scores.shape[1])
|
||||
if n_records_total == 0:
|
||||
per_cause_df = pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": np.zeros((K,), dtype=np.int64),
|
||||
"n_neg": np.zeros((K,), dtype=np.int64),
|
||||
"auc": np.full((K,), np.nan, dtype=np.float64),
|
||||
"included": np.zeros((K,), dtype=bool),
|
||||
}
|
||||
)
|
||||
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "mrr", "value": float("nan")})
|
||||
for k in [1, 3, 5, 10, 20]:
|
||||
metrics_rows.append(
|
||||
{"metric": f"hitrate_at_{k}", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
||||
return metrics_rows, per_cause_df
|
||||
|
||||
# If no eligible, keep coverage but leave accuracy-like metrics as NaN.
|
||||
if n_eligible == 0:
|
||||
per_cause_df = pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": np.zeros((K,), dtype=np.int64),
|
||||
"n_neg": np.full((K,), n_records_total, dtype=np.int64),
|
||||
"auc": np.full((K,), np.nan, dtype=np.float64),
|
||||
"included": np.zeros((K,), dtype=bool),
|
||||
}
|
||||
)
|
||||
metrics_rows.append({"metric": "top1_accuracy", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "mrr", "value": float("nan")})
|
||||
for k in [1, 3, 5, 10, 20]:
|
||||
metrics_rows.append(
|
||||
{"metric": f"hitrate_at_{k}", "value": float("nan")})
|
||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
||||
return metrics_rows, per_cause_df
|
||||
|
||||
scores_e = scores[eligible]
|
||||
y_e = y_next[eligible]
|
||||
|
||||
pred = scores_e.argmax(axis=1)
|
||||
acc = float((pred == y_e).mean())
|
||||
metrics_rows.append({"metric": "top1_accuracy", "value": acc})
|
||||
|
||||
# MRR
|
||||
order = np.argsort(-scores_e, axis=1, kind="mergesort")
|
||||
ranks = np.empty(y_e.shape[0], dtype=np.int32)
|
||||
for i in range(y_e.shape[0]):
|
||||
ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
|
||||
mrr = float((1.0 / ranks).mean())
|
||||
metrics_rows.append({"metric": "mrr", "value": mrr})
|
||||
|
||||
# HitRate@K
|
||||
for k in [1, 3, 5, 10, 20]:
|
||||
topk = topk_indices(scores_e, k)
|
||||
hit = (topk == y_e[:, None]).any(axis=1)
|
||||
metrics_rows.append({"metric": f"hitrate_at_{k}",
|
||||
"value": float(hit.mean())})
|
||||
|
||||
# Macro OvR AUC per cause (optional)
|
||||
n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
|
||||
n_neg = (int(y_e.size) - n_pos).astype(np.int64)
|
||||
|
||||
auc = np.full((K,), np.nan, dtype=np.float64)
|
||||
candidates = np.flatnonzero((n_pos >= int(min_pos)) & (n_neg > 0))
|
||||
for k in candidates:
|
||||
auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
|
||||
|
||||
included = (n_pos >= int(min_pos)) & (n_neg > 0)
|
||||
per_cause_df = pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": n_pos,
|
||||
"n_neg": n_neg,
|
||||
"auc": auc,
|
||||
"included": included,
|
||||
}
|
||||
)
|
||||
|
||||
aucs = auc[np.isfinite(auc)]
|
||||
if aucs.size > 0:
|
||||
metrics_rows.append(
|
||||
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
|
||||
else:
|
||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
||||
|
||||
return metrics_rows, per_cause_df
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
seed_everything(args.seed)
|
||||
@@ -118,92 +240,68 @@ def main() -> None:
|
||||
if scores.ndim == 3:
|
||||
scores = scores[:, :, 0]
|
||||
|
||||
n_records_total = len(records)
|
||||
y_next = np.array(
|
||||
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
|
||||
for r in records],
|
||||
dtype=np.int64,
|
||||
)
|
||||
eligible = y_next >= 0
|
||||
n_eligible = int(eligible.sum())
|
||||
coverage = float(
|
||||
n_eligible / n_records_total) if n_records_total > 0 else 0.0
|
||||
|
||||
metrics_rows: List[dict] = []
|
||||
metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
|
||||
metrics_rows.append(
|
||||
{"metric": "n_next_event_eligible", "value": n_eligible})
|
||||
metrics_rows.append({"metric": "coverage", "value": coverage})
|
||||
metrics_rows.append({"metric": "tau_short_years", "value": tau})
|
||||
|
||||
if n_eligible == 0:
|
||||
out_path = os.path.join(run_dir, "next_event_metrics.csv")
|
||||
pd.DataFrame(metrics_rows).to_csv(out_path, index=False)
|
||||
print(f"No eligible records; wrote {out_path}")
|
||||
return
|
||||
|
||||
scores_e = scores[eligible]
|
||||
y_e = y_next[eligible]
|
||||
|
||||
pred = scores_e.argmax(axis=1)
|
||||
acc = float((pred == y_e).mean())
|
||||
metrics_rows.append({"metric": "top1_accuracy", "value": acc})
|
||||
|
||||
# MRR
|
||||
order = np.argsort(-scores_e, axis=1, kind="mergesort")
|
||||
ranks = np.empty(y_e.shape[0], dtype=np.int32)
|
||||
for i in range(y_e.shape[0]):
|
||||
ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
|
||||
mrr = float((1.0 / ranks).mean())
|
||||
metrics_rows.append({"metric": "mrr", "value": mrr})
|
||||
|
||||
# HitRate@K
|
||||
for k in [1, 3, 5, 10, 20]:
|
||||
topk = topk_indices(scores_e, k)
|
||||
hit = (topk == y_e[:, None]).any(axis=1)
|
||||
metrics_rows.append({"metric": f"hitrate_at_{k}",
|
||||
"value": float(hit.mean())})
|
||||
|
||||
# Macro OvR AUC per cause (optional)
|
||||
K = scores.shape[1]
|
||||
n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
|
||||
n_neg = (int(y_e.size) - n_pos).astype(np.int64)
|
||||
|
||||
auc = np.full((K,), np.nan, dtype=np.float64)
|
||||
min_pos = int(args.min_pos)
|
||||
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
|
||||
for k in candidates:
|
||||
auc_k = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
|
||||
auc[k] = auc_k
|
||||
|
||||
included = (n_pos >= min_pos) & (n_neg > 0)
|
||||
per_cause_df = pd.DataFrame(
|
||||
{
|
||||
"cause_id": np.arange(K, dtype=np.int64),
|
||||
"n_pos": n_pos,
|
||||
"n_neg": n_neg,
|
||||
"auc": auc,
|
||||
"included": included,
|
||||
}
|
||||
# Overall (preserve existing output files/shape)
|
||||
metrics_rows, per_cause_df = _compute_next_event_metrics(
|
||||
scores=scores,
|
||||
y_next=y_next,
|
||||
tau_short=tau,
|
||||
min_pos=int(args.min_pos),
|
||||
)
|
||||
|
||||
aucs = auc[np.isfinite(auc)]
|
||||
|
||||
if aucs.size > 0:
|
||||
metrics_rows.append(
|
||||
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
|
||||
else:
|
||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
||||
|
||||
out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
|
||||
pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
|
||||
|
||||
# optional per-cause
|
||||
out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
|
||||
per_cause_df.to_csv(out_pc, index=False)
|
||||
|
||||
# By age bin (new outputs)
|
||||
age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
|
||||
age_bins_days = age_bins_years * DAYS_PER_YEAR
|
||||
# Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
|
||||
t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
|
||||
bin_idx = np.searchsorted(age_bins_days, t0_days, side="left") - 1
|
||||
|
||||
per_bin_metric_rows: List[dict] = []
|
||||
per_bin_cause_parts: List[pd.DataFrame] = []
|
||||
for b in range(len(age_bins_years) - 1):
|
||||
lo = float(age_bins_years[b])
|
||||
hi = float(age_bins_years[b + 1])
|
||||
label = _format_age_bin_label(lo, hi)
|
||||
m = bin_idx == b
|
||||
m_scores = scores[m]
|
||||
m_y = y_next[m]
|
||||
m_rows, m_pc = _compute_next_event_metrics(
|
||||
scores=m_scores,
|
||||
y_next=m_y,
|
||||
tau_short=tau,
|
||||
min_pos=int(args.min_pos),
|
||||
)
|
||||
for row in m_rows:
|
||||
per_bin_metric_rows.append({"age_bin": label, **row})
|
||||
m_pc = m_pc.copy()
|
||||
m_pc.insert(0, "age_bin", label)
|
||||
per_bin_cause_parts.append(m_pc)
|
||||
|
||||
out_metrics_bins = os.path.join(
|
||||
run_dir, "next_event_metrics_by_age_bin.csv")
|
||||
pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
|
||||
out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
|
||||
if per_bin_cause_parts:
|
||||
pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
|
||||
out_pc_bins, index=False)
|
||||
else:
|
||||
pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
|
||||
"auc", "included"]).to_csv(out_pc_bins, index=False)
|
||||
|
||||
print(f"Wrote {out_metrics}")
|
||||
print(f"Wrote {out_pc}")
|
||||
print(f"Wrote {out_metrics_bins}")
|
||||
print(f"Wrote {out_pc_bins}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user