Add tqdm progress bar support and disable option for evaluation scripts

This commit is contained in:
2026-01-17 14:00:42 +08:00
parent bfab601a77
commit 67f92ce6c4
3 changed files with 79 additions and 7 deletions

View File

@@ -19,6 +19,11 @@ import pandas as pd
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
try:
from tqdm import tqdm
except Exception: # pragma: no cover
tqdm = None
from utils import ( from utils import (
EvalRecordDataset, EvalRecordDataset,
build_dataset_from_config, build_dataset_from_config,
@@ -72,6 +77,11 @@ def parse_args() -> argparse.Namespace:
nargs="+", nargs="+",
default=[1, 5, 10, 20, 50], default=[1, 5, 10, 20, 50],
) )
p.add_argument(
"--no_tqdm",
action="store_true",
help="Disable tqdm progress bars",
)
return p.parse_args() return p.parse_args()
@@ -102,6 +112,8 @@ def main() -> None:
args = parse_args() args = parse_args()
seed_everything(args.seed) seed_everything(args.seed)
show_progress = (not args.no_tqdm)
run_dir = args.run_dir run_dir = args.run_dir
cfg = load_train_config(run_dir) cfg = load_train_config(run_dir)
@@ -117,6 +129,7 @@ def main() -> None:
subset=test_subset, subset=test_subset,
age_bins_years=age_bins_years, age_bins_years=age_bins_years,
seed=args.seed, seed=args.seed,
show_progress=show_progress,
) )
device = torch.device(args.device) device = torch.device(args.device)
@@ -139,8 +152,16 @@ def main() -> None:
print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).") print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).") print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
scores = predict_cifs(model, head, criterion, loader, scores = predict_cifs(
horizons, device=device) model,
head,
criterion,
loader,
horizons,
device=device,
show_progress=show_progress,
progress_desc="Inference (horizons)",
)
# scores shape: (N, K, H) # scores shape: (N, K, H)
if scores.ndim != 3: if scores.ndim != 3:
raise ValueError( raise ValueError(
@@ -157,7 +178,12 @@ def main() -> None:
per_cause_rows: List[Dict[str, object]] = [] per_cause_rows: List[Dict[str, object]] = []
workload_rows: List[Dict[str, object]] = [] workload_rows: List[Dict[str, object]] = []
for h_idx, tau in enumerate(horizons): horizon_iter = enumerate(horizons)
if show_progress and tqdm is not None:
horizon_iter = tqdm(horizon_iter, total=len(
horizons), desc="Metrics by horizon")
for h_idx, tau in horizon_iter:
s_tau = scores[:, :, h_idx] s_tau = scores[:, :, h_idx]
y_tau = build_labels_within_tau_flat( y_tau = build_labels_within_tau_flat(
N, K, evt_rec_idx, evt_cause, evt_dt, tau) N, K, evt_rec_idx, evt_cause, evt_dt, tau)

View File

@@ -7,6 +7,11 @@ import pandas as pd
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
try:
from tqdm import tqdm # noqa: F401
except Exception: # pragma: no cover
tqdm = None
from utils import ( from utils import (
EvalRecordDataset, EvalRecordDataset,
build_dataset_from_config, build_dataset_from_config,
@@ -53,6 +58,11 @@ def parse_args() -> argparse.Namespace:
default=20, default=20,
help="Minimum positives for per-cause AUC", help="Minimum positives for per-cause AUC",
) )
p.add_argument(
"--no_tqdm",
action="store_true",
help="Disable tqdm progress bars",
)
return p.parse_args() return p.parse_args()
@@ -60,6 +70,8 @@ def main() -> None:
args = parse_args() args = parse_args()
seed_everything(args.seed) seed_everything(args.seed)
show_progress = (not args.no_tqdm)
run_dir = args.run_dir run_dir = args.run_dir
cfg = load_train_config(run_dir) cfg = load_train_config(run_dir)
@@ -72,6 +84,7 @@ def main() -> None:
subset=test_subset, subset=test_subset,
age_bins_years=age_bins_years, age_bins_years=age_bins_years,
seed=args.seed, seed=args.seed,
show_progress=show_progress,
) )
device = torch.device(args.device) device = torch.device(args.device)
@@ -91,7 +104,16 @@ def main() -> None:
) )
tau = float(args.tau_short) tau = float(args.tau_short)
scores = predict_cifs(model, head, criterion, loader, [tau], device=device) scores = predict_cifs(
model,
head,
criterion,
loader,
[tau],
device=device,
show_progress=show_progress,
progress_desc="Inference (next-event)",
)
# scores shape: (N,K,1) for multi-taus; squeeze last # scores shape: (N,K,1) for multi-taus; squeeze last
if scores.ndim == 3: if scores.ndim == 3:
scores = scores[:, :, 0] scores = scores[:, :, 0]
@@ -167,7 +189,7 @@ def main() -> None:
aucs = auc[np.isfinite(auc)] aucs = auc[np.isfinite(auc)]
if aucs: if aucs.size > 0:
metrics_rows.append( metrics_rows.append(
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))}) {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
else: else:

View File

@@ -10,6 +10,11 @@ import numpy as np
import torch import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split from torch.utils.data import DataLoader, Dataset, Subset, random_split
try:
from tqdm import tqdm as _tqdm
except Exception: # pragma: no cover
_tqdm = None
from dataset import HealthDataset from dataset import HealthDataset
from losses import ( from losses import (
DiscreteTimeCIFNLLLoss, DiscreteTimeCIFNLLLoss,
@@ -23,6 +28,12 @@ DAYS_PER_YEAR = 365.25
N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2 N_TECH_TOKENS = 2 # pad=0, DOA=1, diseases start at 2
def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None):
if enabled and _tqdm is not None:
return _tqdm(iterable, desc=desc, total=total)
return iterable
def make_inference_dataloader_kwargs( def make_inference_dataloader_kwargs(
device: torch.device, device: torch.device,
num_workers: int, num_workers: int,
@@ -278,6 +289,7 @@ def build_event_driven_records(
subset: Subset, subset: Subset,
age_bins_years: Sequence[float], age_bins_years: Sequence[float],
seed: int, seed: int,
show_progress: bool = False,
) -> List[EvalRecord]: ) -> List[EvalRecord]:
if len(age_bins_years) < 2: if len(age_bins_years) < 2:
raise ValueError("age_bins must have at least 2 boundaries") raise ValueError("age_bins must have at least 2 boundaries")
@@ -296,7 +308,12 @@ def build_event_driven_records(
# Speed: avoid calling dataset.__getitem__ for every patient here. # Speed: avoid calling dataset.__getitem__ for every patient here.
# We only need DOA + event times/codes to create evaluation records. # We only need DOA + event times/codes to create evaluation records.
eps = 1e-6 eps = 1e-6
for patient_idx in indices: for patient_idx in _progress(
indices,
enabled=show_progress,
desc="Building eval records",
total=len(indices),
):
patient_id = dataset.patient_ids[patient_idx] patient_id = dataset.patient_ids[patient_idx]
doa_days = float(dataset._doa[patient_idx]) doa_days = float(dataset._doa[patient_idx])
@@ -445,6 +462,8 @@ def predict_cifs(
loader: DataLoader, loader: DataLoader,
taus_years: Sequence[float], taus_years: Sequence[float],
device: torch.device, device: torch.device,
show_progress: bool = False,
progress_desc: str = "Inference",
) -> np.ndarray: ) -> np.ndarray:
model.eval() model.eval()
head.eval() head.eval()
@@ -453,7 +472,12 @@ def predict_cifs(
all_out: List[np.ndarray] = [] all_out: List[np.ndarray] = []
with torch.no_grad(): with torch.no_grad():
for batch in loader: for batch in _progress(
loader,
enabled=show_progress,
desc=progress_desc,
total=len(loader) if hasattr(loader, "__len__") else None,
):
event_seq, time_seq, cont, cate, sex, baseline_pos = batch event_seq, time_seq, cont, cate, sex, baseline_pos = batch
event_seq = event_seq.to(device, non_blocking=True) event_seq = event_seq.to(device, non_blocking=True)
time_seq = time_seq.to(device, non_blocking=True) time_seq = time_seq.to(device, non_blocking=True)