Add tqdm progress bar support and disable option for evaluation scripts

This commit is contained in:
2026-01-17 14:00:42 +08:00
parent bfab601a77
commit 67f92ce6c4
3 changed files with 79 additions and 7 deletions

View File

@@ -19,6 +19,11 @@ import pandas as pd
import torch
from torch.utils.data import DataLoader
try:
from tqdm import tqdm
except Exception: # pragma: no cover
tqdm = None
from utils import (
EvalRecordDataset,
build_dataset_from_config,
@@ -72,6 +77,11 @@ def parse_args() -> argparse.Namespace:
nargs="+",
default=[1, 5, 10, 20, 50],
)
p.add_argument(
"--no_tqdm",
action="store_true",
help="Disable tqdm progress bars",
)
return p.parse_args()
@@ -102,6 +112,8 @@ def main() -> None:
args = parse_args()
seed_everything(args.seed)
show_progress = (not args.no_tqdm)
run_dir = args.run_dir
cfg = load_train_config(run_dir)
@@ -117,6 +129,7 @@ def main() -> None:
subset=test_subset,
age_bins_years=age_bins_years,
seed=args.seed,
show_progress=show_progress,
)
device = torch.device(args.device)
@@ -139,8 +152,16 @@ def main() -> None:
print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
scores = predict_cifs(model, head, criterion, loader,
horizons, device=device)
scores = predict_cifs(
model,
head,
criterion,
loader,
horizons,
device=device,
show_progress=show_progress,
progress_desc="Inference (horizons)",
)
# scores shape: (N, K, H)
if scores.ndim != 3:
raise ValueError(
@@ -157,7 +178,12 @@ def main() -> None:
per_cause_rows: List[Dict[str, object]] = []
workload_rows: List[Dict[str, object]] = []
for h_idx, tau in enumerate(horizons):
horizon_iter = enumerate(horizons)
if show_progress and tqdm is not None:
horizon_iter = tqdm(horizon_iter, total=len(
horizons), desc="Metrics by horizon")
for h_idx, tau in horizon_iter:
s_tau = scores[:, :, h_idx]
y_tau = build_labels_within_tau_flat(
N, K, evt_rec_idx, evt_cause, evt_dt, tau)