diff --git a/evaluate_next_event.py b/evaluate_next_event.py index d8c8a89..7858366 100644 --- a/evaluate_next_event.py +++ b/evaluate_next_event.py @@ -186,6 +186,22 @@ def _compute_next_event_auc_clean_control( def main() -> None: args = parse_args() + + # Best-effort control of implicit parallelism to avoid CPU oversubscription. + # Note: environment variables are ideally set before importing NumPy/PyTorch, + # but setting them early in main can still affect subprocesses or lazy readers. + if int(args.max_cpu_cores) > 0: + n_threads = int(args.max_cpu_cores) + torch.set_num_threads(n_threads) + for k in ( + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + "OPENBLAS_NUM_THREADS", + "VECLIB_MAXIMUM_THREADS", + "NUMEXPR_NUM_THREADS", + ): + os.environ[k] = str(n_threads) + print(f"Restricting implicit parallelism to {n_threads} threads.") seed_everything(args.seed) show_progress = (not args.no_tqdm)