Add tqdm progress bar support and disable option for evaluation scripts

This commit is contained in:
2026-01-17 14:00:42 +08:00
parent bfab601a77
commit 67f92ce6c4
3 changed files with 79 additions and 7 deletions

View File

@@ -10,6 +10,11 @@ import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split
try:
from tqdm import tqdm as _tqdm
except Exception: # pragma: no cover
_tqdm = None
from dataset import HealthDataset
from losses import (
DiscreteTimeCIFNLLLoss,
@@ -23,6 +28,12 @@ DAYS_PER_YEAR = 365.25
N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2
def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None):
if enabled and _tqdm is not None:
return _tqdm(iterable, desc=desc, total=total)
return iterable
def make_inference_dataloader_kwargs(
device: torch.device,
num_workers: int,
@@ -278,6 +289,7 @@ def build_event_driven_records(
subset: Subset,
age_bins_years: Sequence[float],
seed: int,
show_progress: bool = False,
) -> List[EvalRecord]:
if len(age_bins_years) < 2:
raise ValueError("age_bins must have at least 2 boundaries")
@@ -296,7 +308,12 @@ def build_event_driven_records(
# Speed: avoid calling dataset.__getitem__ for every patient here.
# We only need DOA + event times/codes to create evaluation records.
eps = 1e-6
for patient_idx in indices:
for patient_idx in _progress(
indices,
enabled=show_progress,
desc="Building eval records",
total=len(indices),
):
patient_id = dataset.patient_ids[patient_idx]
doa_days = float(dataset._doa[patient_idx])
@@ -445,6 +462,8 @@ def predict_cifs(
loader: DataLoader,
taus_years: Sequence[float],
device: torch.device,
show_progress: bool = False,
progress_desc: str = "Inference",
) -> np.ndarray:
model.eval()
head.eval()
@@ -453,7 +472,12 @@ def predict_cifs(
all_out: List[np.ndarray] = []
with torch.no_grad():
for batch in loader:
for batch in _progress(
loader,
enabled=show_progress,
desc=progress_desc,
total=len(loader) if hasattr(loader, "__len__") else None,
):
event_seq, time_seq, cont, cate, sex, baseline_pos = batch
event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True)