From cfe7f88162e2afb5c9959abfa8e9081fd61c137b Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sat, 17 Jan 2026 23:53:24 +0800 Subject: [PATCH] Add --max_cpu_cores argument for parallel processing in evaluation scripts --- evaluate_horizon.py | 7 ++ evaluate_next_event.py | 59 +++++++++-- run_evaluations_multi_gpu.sh | 4 +- utils.py | 191 +++++++++++++++++++++++++++++++---- 4 files changed, 234 insertions(+), 27 deletions(-) diff --git a/evaluate_horizon.py b/evaluate_horizon.py index 9c9257c..a288f9a 100644 --- a/evaluate_horizon.py +++ b/evaluate_horizon.py @@ -77,6 +77,12 @@ def parse_args() -> argparse.Namespace: ) p.add_argument("--batch_size", type=int, default=256) p.add_argument("--num_workers", type=int, default=0) + p.add_argument( + "--max_cpu_cores", + type=int, + default=-1, + help="Maximum number of CPU cores to use for parallel data construction.", + ) p.add_argument("--seed", type=int, default=0) p.add_argument("--min_pos", type=int, default=20) p.add_argument( @@ -171,6 +177,7 @@ def main() -> None: age_bins_years=age_bins_years, seed=args.seed, show_progress=show_progress, + n_jobs=int(args.max_cpu_cores), ) device = torch.device(args.device) diff --git a/evaluate_next_event.py b/evaluate_next_event.py index b5b0e61..d8c8a89 100644 --- a/evaluate_next_event.py +++ b/evaluate_next_event.py @@ -50,6 +50,12 @@ def parse_args() -> argparse.Namespace: ) p.add_argument("--batch_size", type=int, default=256) p.add_argument("--num_workers", type=int, default=0) + p.add_argument( + "--max_cpu_cores", + type=int, + default=-1, + help="Maximum number of CPU cores to use for parallel data construction.", + ) p.add_argument("--seed", type=int, default=0) p.add_argument( "--min_pos", @@ -96,6 +102,45 @@ def _compute_next_event_auc_clean_control( dtype=np.int64, ) + # Pre-compute lifetime disease membership matrix for vectorized clean-control filtering. + # lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes. + # Use a sparse matrix when SciPy is available to keep memory bounded for large K. + row_parts: List[np.ndarray] = [] + col_parts: List[np.ndarray] = [] + for i, r in enumerate(records): + causes = getattr(r, "lifetime_causes", None) + if causes is None: + continue + causes = np.asarray(causes, dtype=np.int64) + if causes.size == 0: + continue + # Keep only valid cause ids. + m_valid = (causes >= 0) & (causes < K) + if not np.any(m_valid): + continue + causes = causes[m_valid] + row_parts.append(np.full((causes.size,), i, dtype=np.int32)) + col_parts.append(causes.astype(np.int32, copy=False)) + + try: + import scipy.sparse as sp # type: ignore + + if row_parts: + rows = np.concatenate(row_parts, axis=0) + cols = np.concatenate(col_parts, axis=0) + data = np.ones((rows.size,), dtype=bool) + lifetime_matrix = sp.csc_matrix( + (data, (rows, cols)), shape=(n_records, K)) + else: + lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool) + lifetime_is_sparse = True + except Exception: # pragma: no cover + lifetime_matrix = np.zeros((n_records, K), dtype=bool) + for rows, cols in zip(row_parts, col_parts): + lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype( + np.int64, copy=False)] = True + lifetime_is_sparse = False + auc = np.full((K,), np.nan, dtype=np.float64) var = np.full((K,), np.nan, dtype=np.float64) n_case = np.zeros((K,), dtype=np.int64) @@ -109,12 +154,13 @@ def _compute_next_event_auc_clean_control( # Clean controls: not next-event k AND never had k in their lifetime history. 
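+        # The precomputed lifetime_matrix turns this check into one column
+        # lookup per cause instead of a Python-level pass over all records.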
control_mask = y_next != k if np.any(control_mask): - clean = np.fromiter( - ((k not in rec.lifetime_causes) for rec in records), - dtype=bool, - count=n_records, - ) - control_mask = control_mask & clean + if lifetime_is_sparse: + had_k = np.asarray(lifetime_matrix.getcol( + k).toarray().ravel(), dtype=bool) + else: + had_k = lifetime_matrix[:, k] + is_clean = ~had_k + control_mask = control_mask & is_clean cs = scores[case_mask, k] hs = scores[control_mask, k] @@ -156,6 +202,7 @@ def main() -> None: age_bins_years=age_bins_years, seed=args.seed, show_progress=show_progress, + n_jobs=int(args.max_cpu_cores), ) device = torch.device(args.device) diff --git a/run_evaluations_multi_gpu.sh b/run_evaluations_multi_gpu.sh index 9bb498a..a60f00b 100644 --- a/run_evaluations_multi_gpu.sh +++ b/run_evaluations_multi_gpu.sh @@ -31,7 +31,7 @@ Options: Common eval args: Anything after `--` is appended to BOTH evaluation commands. - Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --seed, --min_pos, --no_tqdm). + Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --max_cpu_cores, --seed, --min_pos, --no_tqdm). Per-eval args: For eval-specific flags (e.g. evaluate_horizon.py --topk_list / --workload_fracs), use --horizon-args-file. @@ -41,6 +41,8 @@ Examples: ./run_evaluations_multi_gpu.sh --gpus 0,1 ./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \ --horizons 0.25 0.5 1 2 5 10 --age-bins 40 45 50 55 60 65 70 75 inf -- --batch_size 512 --num_workers 4 + ./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \ + -- --batch_size 512 --num_workers 4 --max_cpu_cores -1 USAGE } diff --git a/utils.py b/utils.py index 2ed23c9..d634382 100644 --- a/utils.py +++ b/utils.py @@ -15,6 +15,12 @@ try: except Exception: # pragma: no cover _tqdm = None +try: + from joblib import Parallel, delayed # type: ignore +except Exception: # pragma: no cover + Parallel = None + delayed = None + from dataset import HealthDataset from losses import ( DiscreteTimeCIFNLLLoss, @@ -290,6 +296,9 @@ def build_event_driven_records( age_bins_years: Sequence[float], seed: int, show_progress: bool = False, + n_jobs: int = 1, + chunk_size: int = 256, + prefer: str = "threads", ) -> List[EvalRecord]: if len(age_bins_years) < 2: raise ValueError("age_bins must have at least 2 boundaries") @@ -298,20 +307,20 @@ def build_event_driven_records( if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)): raise ValueError("age_bins must be strictly increasing") - rng = np.random.default_rng(seed) + def _iter_chunks(n: int, size: int) -> List[np.ndarray]: + if size <= 0: + raise ValueError("chunk_size must be >= 1") + if n == 0: + return [] + idx = np.arange(n, dtype=np.int64) + return [idx[i:i + size] for i in range(0, n, size)] - records: List[EvalRecord] = [] - - # Build records exclusively from the provided subset. - # We intentionally avoid reading from subset.dataset internals so the - # evaluation pipeline does not depend on the full dataset object. 
- eps = 1e-6 - for subset_idx in _progress( - range(len(subset)), - enabled=show_progress, - desc="Building eval records", - total=len(subset), - ): + def _build_records_for_index( + subset_idx: int, + *, + age_bins_days_local: Sequence[float], + rng_local: np.random.Generator, + ) -> List[EvalRecord]: event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)] codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False) times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False) @@ -333,12 +342,17 @@ def build_event_driven_records( lifetime_causes = np.zeros((0,), dtype=np.int64) disease_pos_all = np.flatnonzero(is_disease) - disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros( - (0,), dtype=np.float64) + disease_times_all = ( + times_ins[disease_pos_all] + if disease_pos_all.size > 0 + else np.zeros((0,), dtype=np.float64) + ) - for b in range(len(age_bins_days) - 1): - lo = age_bins_days[b] - hi = age_bins_days[b + 1] + eps = 1e-6 + out: List[EvalRecord] = [] + for b in range(len(age_bins_days_local) - 1): + lo = float(age_bins_days_local[b]) + hi = float(age_bins_days_local[b + 1]) # Inclusion rule: # 1) DOA <= bin_upper @@ -359,7 +373,7 @@ def build_event_driven_records( if cand_pos.size == 0: continue - cutoff_pos = int(rng.choice(cand_pos)) + cutoff_pos = int(rng_local.choice(cand_pos)) t0_days = float(times_ins[cutoff_pos]) # Future disease events strictly after t0 @@ -376,13 +390,150 @@ def build_event_driven_records( future_causes = ( future_tokens - N_TECH_TOKENS).astype(np.int64) future_dt_years_arr = ( - (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32) + (future_times_days - t0_days) / DAYS_PER_YEAR + ).astype(np.float32) # next-event = minimal time > t0 (tie broken by earliest position) next_idx = int(np.argmin(future_times_days)) next_cause = int(future_causes[next_idx]) next_dt_years = float(future_dt_years_arr[next_idx]) + out.append( + EvalRecord( + subset_idx=int(subset_idx), + doa_days=float(doa_days), + t0_days=float(t0_days), + cutoff_pos=int(cutoff_pos), + next_event_cause=next_cause, + next_event_dt_years=next_dt_years, + lifetime_causes=lifetime_causes, + future_causes=future_causes, + future_dt_years=future_dt_years_arr, + ) + ) + return out + + def _process_chunk( + chunk_indices: Sequence[int], + *, + age_bins_days_local: Sequence[float], + seed_local: int, + ) -> List[EvalRecord]: + out: List[EvalRecord] = [] + for subset_idx in chunk_indices: + # Ensure each subject has its own deterministic RNG stream, so parallel + # workers do not share identical seeds. + ss = np.random.SeedSequence([int(seed_local), int(subset_idx)]) + rng_local = np.random.default_rng(ss) + out.extend( + _build_records_for_index( + int(subset_idx), + age_bins_days_local=age_bins_days_local, + rng_local=rng_local, + ) + ) + return out + + n = int(len(subset)) + chunks = _iter_chunks(n, int(chunk_size)) + + do_parallel = ( + int(n_jobs) != 1 + and Parallel is not None + and delayed is not None + and n > 0 + ) + + if do_parallel: + # Note: on Windows, process-based parallelism may require the underlying + # dataset to be pickleable. `prefer="threads"` is the default for safety. 
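+        # The threading backend avoids pickling but runs under the GIL, so
+        # speedups depend on how much of `subset[idx]` and the NumPy work
+        # releases it; prefer="processes" can be passed when the data pickles cleanly.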
+ parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)( + delayed(_process_chunk)( + chunk, + age_bins_days_local=age_bins_days, + seed_local=int(seed), + ) + for chunk in chunks + ) + records = [r for part in parts for r in part] + return records + + # Sequential (preserve prior behavior/progress reporting) + rng = np.random.default_rng(seed) + records: List[EvalRecord] = [] + eps = 1e-6 + for subset_idx in _progress( + range(len(subset)), + enabled=show_progress, + desc="Building eval records", + total=len(subset), + ): + event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)] + codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False) + times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False) + + doa_pos = np.flatnonzero(codes_ins == 1) + if doa_pos.size == 0: + raise ValueError("Expected DOA token (code=1) in event sequence") + doa_days = float(times_ins[int(doa_pos[0])]) + + is_disease = codes_ins >= N_TECH_TOKENS + + if np.any(is_disease): + lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype( + np.int64, copy=False + ) + lifetime_causes = np.unique(lifetime_causes) + else: + lifetime_causes = np.zeros((0,), dtype=np.int64) + + disease_pos_all = np.flatnonzero(is_disease) + disease_times_all = ( + times_ins[disease_pos_all] + if disease_pos_all.size > 0 + else np.zeros((0,), dtype=np.float64) + ) + + for b in range(len(age_bins_days) - 1): + lo = age_bins_days[b] + hi = age_bins_days[b + 1] + if not (doa_days <= hi): + continue + if disease_pos_all.size == 0: + continue + + in_bin = ( + (disease_times_all >= lo) + & (disease_times_all < hi) + & (disease_times_all >= doa_days) + ) + cand_pos = disease_pos_all[in_bin] + if cand_pos.size == 0: + continue + + cutoff_pos = int(rng.choice(cand_pos)) + t0_days = float(times_ins[cutoff_pos]) + + future_mask = (times_ins > (t0_days + eps)) & is_disease + future_pos = np.flatnonzero(future_mask) + if future_pos.size == 0: + next_cause = None + next_dt_years = None + future_causes = np.zeros((0,), dtype=np.int64) + future_dt_years_arr = np.zeros((0,), dtype=np.float32) + else: + future_times_days = times_ins[future_pos] + future_tokens = codes_ins[future_pos] + future_causes = ( + future_tokens - N_TECH_TOKENS).astype(np.int64) + future_dt_years_arr = ( + (future_times_days - t0_days) / DAYS_PER_YEAR + ).astype(np.float32) + + next_idx = int(np.argmin(future_times_days)) + next_cause = int(future_causes[next_idx]) + next_dt_years = float(future_dt_years_arr[next_idx]) + records.append( EvalRecord( subset_idx=int(subset_idx),