From 67f92ce6c4760ab1bafee85c6a06e23295d593f1 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sat, 17 Jan 2026 14:00:42 +0800 Subject: [PATCH] Add tqdm progress bar support and disable option for evaluation scripts --- evaluate_horizon.py | 32 +++++++++++++++++++++++++++++--- evaluate_next_event.py | 26 ++++++++++++++++++++++++-- utils.py | 28 ++++++++++++++++++++++++++-- 3 files changed, 79 insertions(+), 7 deletions(-) diff --git a/evaluate_horizon.py b/evaluate_horizon.py index 6ce9fe7..b4c0bc5 100644 --- a/evaluate_horizon.py +++ b/evaluate_horizon.py @@ -19,6 +19,11 @@ import pandas as pd import torch from torch.utils.data import DataLoader +try: + from tqdm import tqdm +except Exception: # pragma: no cover + tqdm = None + from utils import ( EvalRecordDataset, build_dataset_from_config, @@ -72,6 +77,11 @@ def parse_args() -> argparse.Namespace: nargs="+", default=[1, 5, 10, 20, 50], ) + p.add_argument( + "--no_tqdm", + action="store_true", + help="Disable tqdm progress bars", + ) return p.parse_args() @@ -102,6 +112,8 @@ def main() -> None: args = parse_args() seed_everything(args.seed) + show_progress = (not args.no_tqdm) + run_dir = args.run_dir cfg = load_train_config(run_dir) @@ -117,6 +129,7 @@ def main() -> None: subset=test_subset, age_bins_years=age_bins_years, seed=args.seed, + show_progress=show_progress, ) device = torch.device(args.device) @@ -139,8 +152,16 @@ def main() -> None: print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).") print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).") - scores = predict_cifs(model, head, criterion, loader, - horizons, device=device) + scores = predict_cifs( + model, + head, + criterion, + loader, + horizons, + device=device, + show_progress=show_progress, + progress_desc="Inference (horizons)", + ) # scores shape: (N, K, H) if scores.ndim != 3: raise ValueError( @@ -157,7 +178,12 @@ def main() -> None: per_cause_rows: List[Dict[str, object]] = [] workload_rows: List[Dict[str, object]] = [] - for h_idx, tau in enumerate(horizons): + horizon_iter = enumerate(horizons) + if show_progress and tqdm is not None: + horizon_iter = tqdm(horizon_iter, total=len( + horizons), desc="Metrics by horizon") + + for h_idx, tau in horizon_iter: s_tau = scores[:, :, h_idx] y_tau = build_labels_within_tau_flat( N, K, evt_rec_idx, evt_cause, evt_dt, tau) diff --git a/evaluate_next_event.py b/evaluate_next_event.py index 92124dd..6c811e5 100644 --- a/evaluate_next_event.py +++ b/evaluate_next_event.py @@ -7,6 +7,11 @@ import pandas as pd import torch from torch.utils.data import DataLoader +try: + from tqdm import tqdm # noqa: F401 +except Exception: # pragma: no cover + tqdm = None + from utils import ( EvalRecordDataset, build_dataset_from_config, @@ -53,6 +58,11 @@ def parse_args() -> argparse.Namespace: default=20, help="Minimum positives for per-cause AUC", ) + p.add_argument( + "--no_tqdm", + action="store_true", + help="Disable tqdm progress bars", + ) return p.parse_args() @@ -60,6 +70,8 @@ def main() -> None: args = parse_args() seed_everything(args.seed) + show_progress = (not args.no_tqdm) + run_dir = args.run_dir cfg = load_train_config(run_dir) @@ -72,6 +84,7 @@ def main() -> None: subset=test_subset, age_bins_years=age_bins_years, seed=args.seed, + show_progress=show_progress, ) device = torch.device(args.device) @@ -91,7 +104,16 @@ def main() -> None: ) tau = float(args.tau_short) - scores = predict_cifs(model, head, criterion, loader, [tau], device=device) + scores = predict_cifs( + model, + head, + criterion, + loader, + [tau], + device=device, + show_progress=show_progress, + progress_desc="Inference (next-event)", + ) # scores shape: (N,K,1) for multi-taus; squeeze last if scores.ndim == 3: scores = scores[:, :, 0] @@ -167,7 +189,7 @@ def main() -> None: aucs = auc[np.isfinite(auc)] - if aucs: + if aucs.size > 0: metrics_rows.append( {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))}) else: diff --git a/utils.py b/utils.py index d05c4e9..c17ec7d 100644 --- a/utils.py +++ b/utils.py @@ -10,6 +10,11 @@ import numpy as np import torch from torch.utils.data import DataLoader, Dataset, Subset, random_split +try: + from tqdm import tqdm as _tqdm +except Exception: # pragma: no cover + _tqdm = None + from dataset import HealthDataset from losses import ( DiscreteTimeCIFNLLLoss, @@ -23,6 +28,12 @@ DAYS_PER_YEAR = 365.25 N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2 +def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None): + if enabled and _tqdm is not None: + return _tqdm(iterable, desc=desc, total=total) + return iterable + + def make_inference_dataloader_kwargs( device: torch.device, num_workers: int, @@ -278,6 +289,7 @@ def build_event_driven_records( subset: Subset, age_bins_years: Sequence[float], seed: int, + show_progress: bool = False, ) -> List[EvalRecord]: if len(age_bins_years) < 2: raise ValueError("age_bins must have at least 2 boundaries") @@ -296,7 +308,12 @@ def build_event_driven_records( # Speed: avoid calling dataset.__getitem__ for every patient here. # We only need DOA + event times/codes to create evaluation records. eps = 1e-6 - for patient_idx in indices: + for patient_idx in _progress( + indices, + enabled=show_progress, + desc="Building eval records", + total=len(indices), + ): patient_id = dataset.patient_ids[patient_idx] doa_days = float(dataset._doa[patient_idx]) @@ -445,6 +462,8 @@ def predict_cifs( loader: DataLoader, taus_years: Sequence[float], device: torch.device, + show_progress: bool = False, + progress_desc: str = "Inference", ) -> np.ndarray: model.eval() head.eval() @@ -453,7 +472,12 @@ def predict_cifs( all_out: List[np.ndarray] = [] with torch.no_grad(): - for batch in loader: + for batch in _progress( + loader, + enabled=show_progress, + desc=progress_desc, + total=len(loader) if hasattr(loader, "__len__") else None, + ): event_seq, time_seq, cont, cate, sex, baseline_pos = batch event_seq = event_seq.to(device, non_blocking=True) time_seq = time_seq.to(device, non_blocking=True)