Refactor next-event evaluation logic and add age-bin metrics output
This commit is contained in:
@@ -27,6 +27,7 @@ from utils import (
|
|||||||
roc_auc_ovr,
|
roc_auc_ovr,
|
||||||
seed_everything,
|
seed_everything,
|
||||||
topk_indices,
|
topk_indices,
|
||||||
|
DAYS_PER_YEAR,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,6 +67,127 @@ def parse_args() -> argparse.Namespace:
|
|||||||
return p.parse_args()
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _format_age_bin_label(lo: float, hi: float) -> str:
    """Return the half-open interval label for an age bin.

    Args:
        lo: Lower (inclusive) edge of the bin.
        hi: Upper (exclusive) edge of the bin; positive infinity marks an
            open-ended last bin.

    Returns:
        ``"[lo, inf)"`` when ``hi`` is positive infinity, otherwise
        ``"[lo, hi)"`` with both edges rendered by default float formatting.
    """
    # Only +inf denotes an unbounded upper edge; np.isinf would also match
    # -inf and mislabel it as "inf".
    if np.isposinf(hi):
        return f"[{lo}, inf)"
    return f"[{lo}, {hi})"
|
||||||
|
|
||||||
|
|
||||||
|
# Cut-offs reported as hitrate_at_{k}.
_HITRATE_KS = (1, 3, 5, 10, 20)


def _empty_per_cause_frame(n_causes: int) -> pd.DataFrame:
    """Build the per-cause table for a subset with no eligible record."""
    return pd.DataFrame(
        {
            "cause_id": np.arange(n_causes, dtype=np.int64),
            "n_pos": np.zeros((n_causes,), dtype=np.int64),
            "n_neg": np.zeros((n_causes,), dtype=np.int64),
            "auc": np.full((n_causes,), np.nan, dtype=np.float64),
            "included": np.zeros((n_causes,), dtype=bool),
        }
    )


def _compute_next_event_metrics(
    *,
    scores: np.ndarray,
    y_next: np.ndarray,
    tau_short: float,
    min_pos: int,
) -> tuple[list[dict], pd.DataFrame]:
    """Compute next-event metrics on a given subset of records.

    Args:
        scores: 2-D array whose row ``i`` holds one score per cause for
            record ``i``; the number of columns is the number of causes.
        y_next: Cause index of each record's next event. Negative values
            mark records without a next event; they count towards coverage
            but are excluded from every accuracy-like metric.
        tau_short: Echoed into the output as ``tau_short_years``; it does not
            enter any computation here.
        min_pos: Minimum number of eligible positives a cause needs for its
            one-vs-rest AUC to be computed.

    Returns:
        ``(metrics_rows, per_cause_df)``. ``metrics_rows`` is a list of
        ``{"metric": name, "value": v}`` dicts in a fixed order (counts,
        coverage, tau, top-1 accuracy, MRR, hit rates, macro AUC).
        ``per_cause_df`` has one row per cause with columns ``cause_id``,
        ``n_pos``, ``n_neg``, ``auc`` and ``included``; ``n_pos``/``n_neg``
        are counted among eligible records only. When no record is eligible
        the accuracy-like metrics are NaN and the per-cause counts are zero.

    Raises:
        IndexError: If an eligible label is not a valid column of ``scores``.
    """
    n_records_total = int(y_next.size)
    eligible = y_next >= 0
    n_eligible = int(eligible.sum())
    coverage = (
        float(n_eligible / n_records_total) if n_records_total > 0 else 0.0
    )

    metrics_rows: list[dict] = [
        {"metric": "n_records_total", "value": n_records_total},
        {"metric": "n_next_event_eligible", "value": n_eligible},
        {"metric": "coverage", "value": coverage},
        {"metric": "tau_short_years", "value": float(tau_short)},
    ]

    K = int(scores.shape[1])

    # No eligible record (this also covers an empty subset): keep the
    # count/coverage rows and report every accuracy-like metric as NaN.
    if n_eligible == 0:
        nan_metrics = (
            ["top1_accuracy", "mrr"]
            + [f"hitrate_at_{k}" for k in _HITRATE_KS]
            + ["macro_ovr_auc"]
        )
        metrics_rows.extend(
            {"metric": name, "value": float("nan")} for name in nan_metrics
        )
        return metrics_rows, _empty_per_cause_frame(K)

    scores_e = scores[eligible]
    y_e = y_next[eligible]
    # Fail loudly instead of silently ranking against a non-existent column.
    if int(y_e.max()) >= K:
        raise IndexError(
            f"next-event label {int(y_e.max())} is out of range for "
            f"{K} score columns"
        )

    top1 = float((scores_e.argmax(axis=1) == y_e).mean())
    metrics_rows.append({"metric": "top1_accuracy", "value": top1})

    # MRR: 1-based position of the true cause in the descending score order.
    # The stable sort keeps the lower cause index first among tied scores.
    order = np.argsort(-scores_e, axis=1, kind="mergesort")
    ranks = (order == y_e[:, None]).argmax(axis=1) + 1
    metrics_rows.append({"metric": "mrr", "value": float((1.0 / ranks).mean())})

    # HitRate@K
    for k in _HITRATE_KS:
        topk = topk_indices(scores_e, k)
        hit = (topk == y_e[:, None]).any(axis=1)
        metrics_rows.append(
            {"metric": f"hitrate_at_{k}", "value": float(hit.mean())}
        )

    # Macro OvR AUC per cause, restricted to causes with enough positives
    # and at least one negative among the eligible records.
    n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
    n_neg = (int(y_e.size) - n_pos).astype(np.int64)
    included = (n_pos >= int(min_pos)) & (n_neg > 0)

    auc = np.full((K,), np.nan, dtype=np.float64)
    for k in np.flatnonzero(included):
        auc[k] = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])

    per_cause_df = pd.DataFrame(
        {
            "cause_id": np.arange(K, dtype=np.int64),
            "n_pos": n_pos,
            "n_neg": n_neg,
            "auc": auc,
            "included": included,
        }
    )

    # Average only over causes whose AUC is finite.
    finite_auc = auc[np.isfinite(auc)]
    macro = float(np.mean(finite_auc)) if finite_auc.size > 0 else float("nan")
    metrics_rows.append({"metric": "macro_ovr_auc", "value": macro})

    return metrics_rows, per_cause_df
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
seed_everything(args.seed)
|
seed_everything(args.seed)
|
||||||
@@ -118,92 +240,68 @@ def main() -> None:
|
|||||||
if scores.ndim == 3:
|
if scores.ndim == 3:
|
||||||
scores = scores[:, :, 0]
|
scores = scores[:, :, 0]
|
||||||
|
|
||||||
n_records_total = len(records)
|
|
||||||
y_next = np.array(
|
y_next = np.array(
|
||||||
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
|
[(-1 if r.next_event_cause is None else int(r.next_event_cause))
|
||||||
for r in records],
|
for r in records],
|
||||||
dtype=np.int64,
|
dtype=np.int64,
|
||||||
)
|
)
|
||||||
eligible = y_next >= 0
|
|
||||||
n_eligible = int(eligible.sum())
|
|
||||||
coverage = float(
|
|
||||||
n_eligible / n_records_total) if n_records_total > 0 else 0.0
|
|
||||||
|
|
||||||
metrics_rows: List[dict] = []
|
# Overall (preserve existing output files/shape)
|
||||||
metrics_rows.append({"metric": "n_records_total", "value": n_records_total})
|
metrics_rows, per_cause_df = _compute_next_event_metrics(
|
||||||
metrics_rows.append(
|
scores=scores,
|
||||||
{"metric": "n_next_event_eligible", "value": n_eligible})
|
y_next=y_next,
|
||||||
metrics_rows.append({"metric": "coverage", "value": coverage})
|
tau_short=tau,
|
||||||
metrics_rows.append({"metric": "tau_short_years", "value": tau})
|
min_pos=int(args.min_pos),
|
||||||
|
|
||||||
if n_eligible == 0:
|
|
||||||
out_path = os.path.join(run_dir, "next_event_metrics.csv")
|
|
||||||
pd.DataFrame(metrics_rows).to_csv(out_path, index=False)
|
|
||||||
print(f"No eligible records; wrote {out_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
scores_e = scores[eligible]
|
|
||||||
y_e = y_next[eligible]
|
|
||||||
|
|
||||||
pred = scores_e.argmax(axis=1)
|
|
||||||
acc = float((pred == y_e).mean())
|
|
||||||
metrics_rows.append({"metric": "top1_accuracy", "value": acc})
|
|
||||||
|
|
||||||
# MRR
|
|
||||||
order = np.argsort(-scores_e, axis=1, kind="mergesort")
|
|
||||||
ranks = np.empty(y_e.shape[0], dtype=np.int32)
|
|
||||||
for i in range(y_e.shape[0]):
|
|
||||||
ranks[i] = int(np.where(order[i] == y_e[i])[0][0]) + 1
|
|
||||||
mrr = float((1.0 / ranks).mean())
|
|
||||||
metrics_rows.append({"metric": "mrr", "value": mrr})
|
|
||||||
|
|
||||||
# HitRate@K
|
|
||||||
for k in [1, 3, 5, 10, 20]:
|
|
||||||
topk = topk_indices(scores_e, k)
|
|
||||||
hit = (topk == y_e[:, None]).any(axis=1)
|
|
||||||
metrics_rows.append({"metric": f"hitrate_at_{k}",
|
|
||||||
"value": float(hit.mean())})
|
|
||||||
|
|
||||||
# Macro OvR AUC per cause (optional)
|
|
||||||
K = scores.shape[1]
|
|
||||||
n_pos = np.bincount(y_e, minlength=K).astype(np.int64)
|
|
||||||
n_neg = (int(y_e.size) - n_pos).astype(np.int64)
|
|
||||||
|
|
||||||
auc = np.full((K,), np.nan, dtype=np.float64)
|
|
||||||
min_pos = int(args.min_pos)
|
|
||||||
candidates = np.flatnonzero((n_pos >= min_pos) & (n_neg > 0))
|
|
||||||
for k in candidates:
|
|
||||||
auc_k = roc_auc_ovr((y_e == k).astype(np.int32), scores_e[:, k])
|
|
||||||
auc[k] = auc_k
|
|
||||||
|
|
||||||
included = (n_pos >= min_pos) & (n_neg > 0)
|
|
||||||
per_cause_df = pd.DataFrame(
|
|
||||||
{
|
|
||||||
"cause_id": np.arange(K, dtype=np.int64),
|
|
||||||
"n_pos": n_pos,
|
|
||||||
"n_neg": n_neg,
|
|
||||||
"auc": auc,
|
|
||||||
"included": included,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
aucs = auc[np.isfinite(auc)]
|
|
||||||
|
|
||||||
if aucs.size > 0:
|
|
||||||
metrics_rows.append(
|
|
||||||
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
|
|
||||||
else:
|
|
||||||
metrics_rows.append({"metric": "macro_ovr_auc", "value": float("nan")})
|
|
||||||
|
|
||||||
out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
|
out_metrics = os.path.join(run_dir, "next_event_metrics.csv")
|
||||||
pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
|
pd.DataFrame(metrics_rows).to_csv(out_metrics, index=False)
|
||||||
|
|
||||||
# optional per-cause
|
|
||||||
out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
|
out_pc = os.path.join(run_dir, "next_event_per_cause.csv")
|
||||||
per_cause_df.to_csv(out_pc, index=False)
|
per_cause_df.to_csv(out_pc, index=False)
|
||||||
|
|
||||||
|
# By age bin (new outputs)
|
||||||
|
age_bins_years = np.asarray(age_bins_years, dtype=np.float64)
|
||||||
|
age_bins_days = age_bins_years * DAYS_PER_YEAR
|
||||||
|
# Bin assignment from t0 (constructed within the bin): [b_i, b_{i+1})
# side="right" returns the first edge strictly greater than t0, so after
# subtracting 1 a t0 exactly equal to an edge b_i lands in bin i (closed on
# the left), matching the "[lo, hi)" labels. side="left" would push such a
# record into bin i-1 and drop records sitting on the first edge.
t0_days = np.asarray([float(r.t0_days) for r in records], dtype=np.float64)
bin_idx = np.searchsorted(age_bins_days, t0_days, side="right") - 1
|
||||||
|
|
||||||
|
per_bin_metric_rows: List[dict] = []
|
||||||
|
per_bin_cause_parts: List[pd.DataFrame] = []
|
||||||
|
for b in range(len(age_bins_years) - 1):
|
||||||
|
lo = float(age_bins_years[b])
|
||||||
|
hi = float(age_bins_years[b + 1])
|
||||||
|
label = _format_age_bin_label(lo, hi)
|
||||||
|
m = bin_idx == b
|
||||||
|
m_scores = scores[m]
|
||||||
|
m_y = y_next[m]
|
||||||
|
m_rows, m_pc = _compute_next_event_metrics(
|
||||||
|
scores=m_scores,
|
||||||
|
y_next=m_y,
|
||||||
|
tau_short=tau,
|
||||||
|
min_pos=int(args.min_pos),
|
||||||
|
)
|
||||||
|
for row in m_rows:
|
||||||
|
per_bin_metric_rows.append({"age_bin": label, **row})
|
||||||
|
m_pc = m_pc.copy()
|
||||||
|
m_pc.insert(0, "age_bin", label)
|
||||||
|
per_bin_cause_parts.append(m_pc)
|
||||||
|
|
||||||
|
out_metrics_bins = os.path.join(
|
||||||
|
run_dir, "next_event_metrics_by_age_bin.csv")
|
||||||
|
pd.DataFrame(per_bin_metric_rows).to_csv(out_metrics_bins, index=False)
|
||||||
|
out_pc_bins = os.path.join(run_dir, "next_event_per_cause_by_age_bin.csv")
|
||||||
|
if per_bin_cause_parts:
|
||||||
|
pd.concat(per_bin_cause_parts, ignore_index=True).to_csv(
|
||||||
|
out_pc_bins, index=False)
|
||||||
|
else:
|
||||||
|
pd.DataFrame(columns=["age_bin", "cause_id", "n_pos", "n_neg",
|
||||||
|
"auc", "included"]).to_csv(out_pc_bins, index=False)
|
||||||
|
|
||||||
print(f"Wrote {out_metrics}")
|
print(f"Wrote {out_metrics}")
|
||||||
print(f"Wrote {out_pc}")
|
print(f"Wrote {out_pc}")
|
||||||
|
print(f"Wrote {out_metrics_bins}")
|
||||||
|
print(f"Wrote {out_pc_bins}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user