Add tqdm progress bar support and disable option for evaluation scripts

This commit is contained in:
2026-01-17 14:00:42 +08:00
parent bfab601a77
commit 67f92ce6c4
3 changed files with 79 additions and 7 deletions

View File

@@ -7,6 +7,11 @@ import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
from tqdm import tqdm # noqa: F401
except Exception: # pragma: no cover
tqdm = None
from utils import (
EvalRecordDataset,
build_dataset_from_config,
@@ -53,6 +58,11 @@ def parse_args() -> argparse.Namespace:
default=20,
help="Minimum positives for per-cause AUC",
)
p.add_argument(
"--no_tqdm",
action="store_true",
help="Disable tqdm progress bars",
)
return p.parse_args()
@@ -60,6 +70,8 @@ def main() -> None:
args = parse_args()
seed_everything(args.seed)
show_progress = (not args.no_tqdm)
run_dir = args.run_dir
cfg = load_train_config(run_dir)
@@ -72,6 +84,7 @@ def main() -> None:
subset=test_subset,
age_bins_years=age_bins_years,
seed=args.seed,
show_progress=show_progress,
)
device = torch.device(args.device)
@@ -91,7 +104,16 @@ def main() -> None:
)
tau = float(args.tau_short)
scores = predict_cifs(model, head, criterion, loader, [tau], device=device)
scores = predict_cifs(
model,
head,
criterion,
loader,
[tau],
device=device,
show_progress=show_progress,
progress_desc="Inference (next-event)",
)
# scores shape: (N,K,1) for multi-taus; squeeze last
if scores.ndim == 3:
scores = scores[:, :, 0]
@@ -167,7 +189,7 @@ def main() -> None:
aucs = auc[np.isfinite(auc)]
if aucs:
if aucs.size > 0:
metrics_rows.append(
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
else: