Add tqdm progress bar support and disable option for evaluation scripts
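For context, a minimal, self-contained sketch of the opt-in progress pattern this commit introduces; the guarded import, the _progress helper, and the --no_tqdm flag mirror the diff below, while the script scaffolding and demo loop are illustrative only:

    # Sketch only: mirrors the guarded import, _progress helper, and --no_tqdm
    # flag added by this commit; the demo loop is illustrative, not from the repo.
    import argparse
    from typing import Optional

    try:
        from tqdm import tqdm as _tqdm
    except Exception:  # pragma: no cover
        _tqdm = None  # tqdm is optional; fall back to plain iteration


    def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None):
        """Wrap an iterable in a tqdm bar when enabled and tqdm is available."""
        if enabled and _tqdm is not None:
            return _tqdm(iterable, desc=desc, total=total)
        return iterable


    if __name__ == "__main__":
        p = argparse.ArgumentParser()
        p.add_argument("--no_tqdm", action="store_true",
                       help="Disable tqdm progress bars")
        args = p.parse_args()
        show_progress = not args.no_tqdm

        items = range(10_000)
        for _ in _progress(items, enabled=show_progress, desc="Demo", total=len(items)):
            pass  # per-item work would go here

Progress bars stay off unless requested, and when tqdm is not installed the helper simply returns the plain iterable, so the evaluation scripts keep working without the extra dependency.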
@@ -19,6 +19,11 @@ import pandas as pd
 import torch
 from torch.utils.data import DataLoader
 
+try:
+    from tqdm import tqdm
+except Exception:  # pragma: no cover
+    tqdm = None
+
 from utils import (
     EvalRecordDataset,
     build_dataset_from_config,
@@ -72,6 +77,11 @@ def parse_args() -> argparse.Namespace:
         nargs="+",
         default=[1, 5, 10, 20, 50],
     )
+    p.add_argument(
+        "--no_tqdm",
+        action="store_true",
+        help="Disable tqdm progress bars",
+    )
     return p.parse_args()
 
 
@@ -102,6 +112,8 @@ def main() -> None:
     args = parse_args()
     seed_everything(args.seed)
 
+    show_progress = (not args.no_tqdm)
+
     run_dir = args.run_dir
     cfg = load_train_config(run_dir)
 
@@ -117,6 +129,7 @@ def main() -> None:
         subset=test_subset,
         age_bins_years=age_bins_years,
         seed=args.seed,
+        show_progress=show_progress,
     )
 
     device = torch.device(args.device)
@@ -139,8 +152,16 @@ def main() -> None:
     print("DISCLAIMER: AUC here is horizon-dependent label AUC (no IPCW / censoring adjustment).")
     print("DISCLAIMER: Brier is unadjusted diagnostic/proxy (no censoring adjustment).")
 
-    scores = predict_cifs(model, head, criterion, loader,
-                          horizons, device=device)
+    scores = predict_cifs(
+        model,
+        head,
+        criterion,
+        loader,
+        horizons,
+        device=device,
+        show_progress=show_progress,
+        progress_desc="Inference (horizons)",
+    )
     # scores shape: (N, K, H)
     if scores.ndim != 3:
         raise ValueError(
@@ -157,7 +178,12 @@ def main() -> None:
     per_cause_rows: List[Dict[str, object]] = []
     workload_rows: List[Dict[str, object]] = []
 
-    for h_idx, tau in enumerate(horizons):
+    horizon_iter = enumerate(horizons)
+    if show_progress and tqdm is not None:
+        horizon_iter = tqdm(horizon_iter, total=len(
+            horizons), desc="Metrics by horizon")
+
+    for h_idx, tau in horizon_iter:
         s_tau = scores[:, :, h_idx]
         y_tau = build_labels_within_tau_flat(
             N, K, evt_rec_idx, evt_cause, evt_dt, tau)
@@ -7,6 +7,11 @@ import pandas as pd
 import torch
 from torch.utils.data import DataLoader
 
+try:
+    from tqdm import tqdm  # noqa: F401
+except Exception:  # pragma: no cover
+    tqdm = None
+
 from utils import (
     EvalRecordDataset,
     build_dataset_from_config,
@@ -53,6 +58,11 @@ def parse_args() -> argparse.Namespace:
         default=20,
         help="Minimum positives for per-cause AUC",
     )
+    p.add_argument(
+        "--no_tqdm",
+        action="store_true",
+        help="Disable tqdm progress bars",
+    )
     return p.parse_args()
 
 
@@ -60,6 +70,8 @@ def main() -> None:
     args = parse_args()
     seed_everything(args.seed)
 
+    show_progress = (not args.no_tqdm)
+
     run_dir = args.run_dir
     cfg = load_train_config(run_dir)
 
@@ -72,6 +84,7 @@ def main() -> None:
         subset=test_subset,
         age_bins_years=age_bins_years,
         seed=args.seed,
+        show_progress=show_progress,
     )
 
     device = torch.device(args.device)
@@ -91,7 +104,16 @@ def main() -> None:
     )
 
     tau = float(args.tau_short)
-    scores = predict_cifs(model, head, criterion, loader, [tau], device=device)
+    scores = predict_cifs(
+        model,
+        head,
+        criterion,
+        loader,
+        [tau],
+        device=device,
+        show_progress=show_progress,
+        progress_desc="Inference (next-event)",
+    )
     # scores shape: (N,K,1) for multi-taus; squeeze last
     if scores.ndim == 3:
         scores = scores[:, :, 0]
@@ -167,7 +189,7 @@ def main() -> None:
 
     aucs = auc[np.isfinite(auc)]
 
-    if aucs:
+    if aucs.size > 0:
         metrics_rows.append(
             {"metric": "macro_ovr_auc", "value": float(np.mean(aucs))})
     else:
utils.py (28 changed lines)
@@ -10,6 +10,11 @@ import numpy as np
 import torch
 from torch.utils.data import DataLoader, Dataset, Subset, random_split
 
+try:
+    from tqdm import tqdm as _tqdm
+except Exception:  # pragma: no cover
+    _tqdm = None
+
 from dataset import HealthDataset
 from losses import (
     DiscreteTimeCIFNLLLoss,
@@ -23,6 +28,12 @@ DAYS_PER_YEAR = 365.25
 N_TECH_TOKENS = 2  # pad=0, DOA=1, diseases start at 2
 
 
+def _progress(iterable, *, enabled: bool, desc: str, total: Optional[int] = None):
+    if enabled and _tqdm is not None:
+        return _tqdm(iterable, desc=desc, total=total)
+    return iterable
+
+
 def make_inference_dataloader_kwargs(
     device: torch.device,
     num_workers: int,
@@ -278,6 +289,7 @@ def build_event_driven_records(
     subset: Subset,
     age_bins_years: Sequence[float],
     seed: int,
+    show_progress: bool = False,
 ) -> List[EvalRecord]:
     if len(age_bins_years) < 2:
         raise ValueError("age_bins must have at least 2 boundaries")
@@ -296,7 +308,12 @@ def build_event_driven_records(
     # Speed: avoid calling dataset.__getitem__ for every patient here.
     # We only need DOA + event times/codes to create evaluation records.
     eps = 1e-6
-    for patient_idx in indices:
+    for patient_idx in _progress(
+        indices,
+        enabled=show_progress,
+        desc="Building eval records",
+        total=len(indices),
+    ):
         patient_id = dataset.patient_ids[patient_idx]
 
         doa_days = float(dataset._doa[patient_idx])
@@ -445,6 +462,8 @@ def predict_cifs(
     loader: DataLoader,
     taus_years: Sequence[float],
     device: torch.device,
+    show_progress: bool = False,
+    progress_desc: str = "Inference",
 ) -> np.ndarray:
     model.eval()
     head.eval()
@@ -453,7 +472,12 @@ def predict_cifs(
 
     all_out: List[np.ndarray] = []
     with torch.no_grad():
-        for batch in loader:
+        for batch in _progress(
+            loader,
+            enabled=show_progress,
+            desc=progress_desc,
+            total=len(loader) if hasattr(loader, "__len__") else None,
+        ):
             event_seq, time_seq, cont, cate, sex, baseline_pos = batch
             event_seq = event_seq.to(device, non_blocking=True)
             time_seq = time_seq.to(device, non_blocking=True)