Add tqdm progress bar support and disable option for evaluation scripts
This commit is contained in:
@@ -7,6 +7,11 @@ import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
try:
|
||||
from tqdm import tqdm # noqa: F401
|
||||
except Exception: # pragma: no cover
|
||||
tqdm = None
|
||||
|
||||
from utils import (
|
||||
EvalRecordDataset,
|
||||
build_dataset_from_config,
|
||||
@@ -53,6 +58,11 @@ def parse_args() -> argparse.Namespace:
|
||||
default=20,
|
||||
help="Minimum positives for per-cause AUC",
|
||||
)
|
||||
p.add_argument(
|
||||
"--no_tqdm",
|
||||
action="store_true",
|
||||
help="Disable tqdm progress bars",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
@@ -60,6 +70,8 @@ def main() -> None:
|
||||
args = parse_args()
|
||||
seed_everything(args.seed)
|
||||
|
||||
show_progress = (not args.no_tqdm)
|
||||
|
||||
run_dir = args.run_dir
|
||||
cfg = load_train_config(run_dir)
|
||||
|
||||
@@ -72,6 +84,7 @@ def main() -> None:
|
||||
subset=test_subset,
|
||||
age_bins_years=age_bins_years,
|
||||
seed=args.seed,
|
||||
show_progress=show_progress,
|
||||
)
|
||||
|
||||
device = torch.device(args.device)
|
||||
@@ -91,7 +104,16 @@ def main() -> None:
|
||||
)
|
||||
|
||||
tau = float(args.tau_short)
|
||||
scores = predict_cifs(model, head, criterion, loader, [tau], device=device)
|
||||
scores = predict_cifs(
|
||||
model,
|
||||
head,
|
||||
criterion,
|
||||
loader,
|
||||
[tau],
|
||||
device=device,
|
||||
show_progress=show_progress,
|
||||
progress_desc="Inference (next-event)",
|
||||
)
|
||||
# scores shape: (N,K,1) for multi-taus; squeeze last
|
||||
if scores.ndim == 3:
|
||||
scores = scores[:, :, 0]
|
||||
@@ -167,7 +189,7 @@ def main() -> None:
|
||||
|
||||
aucs = auc[np.isfinite(auc)]
|
||||
|
||||
if aucs:
|
||||
if aucs.size > 0:
|
||||
metrics_rows.append(
|
||||
{"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user