Add --max_cpu_cores argument for parallel processing in evaluation scripts
utils.py (191 lines changed)
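The diff below covers utils.py only; the --max_cpu_cores flag named in the commit title presumably lives in the evaluation scripts and feeds the new n_jobs parameter of build_event_driven_records. A minimal sketch of that wiring, assuming the flag name from the title and a clamp to the available core count (both unconfirmed by this diff):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--max_cpu_cores", type=int, default=1,
                        help="CPU cores to use when building eval records")
    args = parser.parse_args()
    # Hypothetical mapping onto the new n_jobs parameter shown below.
    n_jobs = max(1, min(args.max_cpu_cores, os.cpu_count() or 1))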
@@ -15,6 +15,12 @@ try:
 except Exception:  # pragma: no cover
     _tqdm = None
 
+try:
+    from joblib import Parallel, delayed  # type: ignore
+except Exception:  # pragma: no cover
+    Parallel = None
+    delayed = None
+
 from dataset import HealthDataset
 from losses import (
     DiscreteTimeCIFNLLLoss,
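This guarded import mirrors the existing tqdm fallback: joblib stays an optional dependency, and the code later checks Parallel is not None before taking the parallel path. The same pattern in isolation, with a toy workload standing in for the real record builder:

    # Optional-dependency guard: joblib may be absent at runtime.
    try:
        from joblib import Parallel, delayed  # type: ignore
    except Exception:  # pragma: no cover
        Parallel = None
        delayed = None

    def _square(x):
        return x * x

    def square_all(xs, n_jobs=1):
        # Fall back to a plain loop when joblib is unavailable or n_jobs == 1.
        if n_jobs != 1 and Parallel is not None:
            return Parallel(n_jobs=n_jobs, prefer="threads")(
                delayed(_square)(x) for x in xs
            )
        return [x * x for x in xs]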
@@ -290,6 +296,9 @@ def build_event_driven_records(
     age_bins_years: Sequence[float],
     seed: int,
     show_progress: bool = False,
+    n_jobs: int = 1,
+    chunk_size: int = 256,
+    prefer: str = "threads",
 ) -> List[EvalRecord]:
     if len(age_bins_years) < 2:
         raise ValueError("age_bins must have at least 2 boundaries")
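With these defaults (n_jobs=1), existing call sites keep the sequential behavior; parallelism is opt-in. A hypothetical call site, assuming the leading parameter (not shown in this hunk) is the subset being evaluated and that the bin edges here are placeholders:

    records = build_event_driven_records(
        subset,
        age_bins_years=[40.0, 50.0, 60.0, 70.0, 80.0],
        seed=42,
        show_progress=False,  # the progress bar applies to the sequential path only
        n_jobs=-1,            # joblib convention: use all available cores
        chunk_size=256,
        prefer="threads",
    )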
@@ -298,20 +307,20 @@ def build_event_driven_records(
     if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
         raise ValueError("age_bins must be strictly increasing")
 
-    rng = np.random.default_rng(seed)
+    def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
+        if size <= 0:
+            raise ValueError("chunk_size must be >= 1")
+        if n == 0:
+            return []
+        idx = np.arange(n, dtype=np.int64)
+        return [idx[i:i + size] for i in range(0, n, size)]
 
-    records: List[EvalRecord] = []
-
-    # Build records exclusively from the provided subset.
-    # We intentionally avoid reading from subset.dataset internals so the
-    # evaluation pipeline does not depend on the full dataset object.
-    eps = 1e-6
-    for subset_idx in _progress(
-        range(len(subset)),
-        enabled=show_progress,
-        desc="Building eval records",
-        total=len(subset),
-    ):
+    def _build_records_for_index(
+        subset_idx: int,
+        *,
+        age_bins_days_local: Sequence[float],
+        rng_local: np.random.Generator,
+    ) -> List[EvalRecord]:
         event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
         codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
         times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
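_iter_chunks simply splits the subset indices into contiguous blocks, one joblib task per block. A standalone check of that behavior (renamed here to avoid implying module-level scope):

    import numpy as np

    def iter_chunks(n: int, size: int):
        # Same logic as the _iter_chunks helper above.
        if size <= 0:
            raise ValueError("chunk_size must be >= 1")
        if n == 0:
            return []
        idx = np.arange(n, dtype=np.int64)
        return [idx[i:i + size] for i in range(0, n, size)]

    print([c.tolist() for c in iter_chunks(7, 3)])
    # [[0, 1, 2], [3, 4, 5], [6]]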
@@ -333,12 +342,17 @@ def build_event_driven_records(
         lifetime_causes = np.zeros((0,), dtype=np.int64)
 
         disease_pos_all = np.flatnonzero(is_disease)
-        disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
-            (0,), dtype=np.float64)
+        disease_times_all = (
+            times_ins[disease_pos_all]
+            if disease_pos_all.size > 0
+            else np.zeros((0,), dtype=np.float64)
+        )
 
-        for b in range(len(age_bins_days) - 1):
-            lo = age_bins_days[b]
-            hi = age_bins_days[b + 1]
+        eps = 1e-6
+        out: List[EvalRecord] = []
+        for b in range(len(age_bins_days_local) - 1):
+            lo = float(age_bins_days_local[b])
+            hi = float(age_bins_days_local[b + 1])
 
             # Inclusion rule:
             #   1) DOA <= bin_upper
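The inclusion rule referenced above filters candidate cutoffs to the half-open bin [lo, hi) and enforces the DOA floor, as the in_bin expression later in this diff shows. In isolation, with made-up times in days:

    import numpy as np

    times = np.array([100.0, 250.0, 400.0])
    lo, hi, doa_days = 200.0, 400.0, 150.0
    in_bin = (times >= lo) & (times < hi) & (times >= doa_days)
    print(in_bin)  # [False  True False]; 400.0 falls outside the open upper bound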
@@ -359,7 +373,7 @@ def build_event_driven_records(
             if cand_pos.size == 0:
                 continue
 
-            cutoff_pos = int(rng.choice(cand_pos))
+            cutoff_pos = int(rng_local.choice(cand_pos))
             t0_days = float(times_ins[cutoff_pos])
 
             # Future disease events strictly after t0
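Switching from the shared rng to a per-call rng_local is what keeps the cutoff draw reproducible under parallelism: as the next hunk shows, each stream is derived from np.random.SeedSequence([seed, subset_idx]), so a subject's draw does not depend on worker or chunk assignment. A quick demonstration of that property:

    import numpy as np

    def draw(seed: int, subset_idx: int) -> int:
        # Same derivation as _process_chunk in the next hunk.
        ss = np.random.SeedSequence([seed, subset_idx])
        rng = np.random.default_rng(ss)
        return int(rng.choice(np.arange(10)))

    print(draw(42, 7) == draw(42, 7))  # True: same (seed, index) -> same draw
    print(draw(42, 7), draw(42, 8))    # independent streams per subject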
@@ -376,13 +390,150 @@ def build_event_driven_records(
                 future_causes = (
                     future_tokens - N_TECH_TOKENS).astype(np.int64)
                 future_dt_years_arr = (
-                    (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
+                    (future_times_days - t0_days) / DAYS_PER_YEAR
+                ).astype(np.float32)
 
                 # next-event = minimal time > t0 (tie broken by earliest position)
                 next_idx = int(np.argmin(future_times_days))
                 next_cause = int(future_causes[next_idx])
                 next_dt_years = float(future_dt_years_arr[next_idx])
 
-            records.append(
+            out.append(
                 EvalRecord(
                     subset_idx=int(subset_idx),
+                    doa_days=float(doa_days),
+                    t0_days=float(t0_days),
+                    cutoff_pos=int(cutoff_pos),
+                    next_event_cause=next_cause,
+                    next_event_dt_years=next_dt_years,
+                    lifetime_causes=lifetime_causes,
+                    future_causes=future_causes,
+                    future_dt_years=future_dt_years_arr,
+                )
+            )
+        return out
+
+    def _process_chunk(
+        chunk_indices: Sequence[int],
+        *,
+        age_bins_days_local: Sequence[float],
+        seed_local: int,
+    ) -> List[EvalRecord]:
+        out: List[EvalRecord] = []
+        for subset_idx in chunk_indices:
+            # Ensure each subject has its own deterministic RNG stream, so parallel
+            # workers do not share identical seeds.
+            ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
+            rng_local = np.random.default_rng(ss)
+            out.extend(
+                _build_records_for_index(
+                    int(subset_idx),
+                    age_bins_days_local=age_bins_days_local,
+                    rng_local=rng_local,
+                )
+            )
+        return out
+
+    n = int(len(subset))
+    chunks = _iter_chunks(n, int(chunk_size))
+
+    do_parallel = (
+        int(n_jobs) != 1
+        and Parallel is not None
+        and delayed is not None
+        and n > 0
+    )
+
+    if do_parallel:
+        # Note: on Windows, process-based parallelism may require the underlying
+        # dataset to be pickleable. `prefer="threads"` is the default for safety.
+        parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
+            delayed(_process_chunk)(
+                chunk,
+                age_bins_days_local=age_bins_days,
+                seed_local=int(seed),
+            )
+            for chunk in chunks
+        )
+        records = [r for part in parts for r in part]
+        return records
+
+    # Sequential (preserve prior behavior/progress reporting)
+    rng = np.random.default_rng(seed)
+    records: List[EvalRecord] = []
+    eps = 1e-6
+    for subset_idx in _progress(
+        range(len(subset)),
+        enabled=show_progress,
+        desc="Building eval records",
+        total=len(subset),
+    ):
+        event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
+        codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
+        times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
+
+        doa_pos = np.flatnonzero(codes_ins == 1)
+        if doa_pos.size == 0:
+            raise ValueError("Expected DOA token (code=1) in event sequence")
+        doa_days = float(times_ins[int(doa_pos[0])])
+
+        is_disease = codes_ins >= N_TECH_TOKENS
+
+        if np.any(is_disease):
+            lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
+                np.int64, copy=False
+            )
+            lifetime_causes = np.unique(lifetime_causes)
+        else:
+            lifetime_causes = np.zeros((0,), dtype=np.int64)
+
+        disease_pos_all = np.flatnonzero(is_disease)
+        disease_times_all = (
+            times_ins[disease_pos_all]
+            if disease_pos_all.size > 0
+            else np.zeros((0,), dtype=np.float64)
+        )
+
+        for b in range(len(age_bins_days) - 1):
+            lo = age_bins_days[b]
+            hi = age_bins_days[b + 1]
+            if not (doa_days <= hi):
+                continue
+            if disease_pos_all.size == 0:
+                continue
+
+            in_bin = (
+                (disease_times_all >= lo)
+                & (disease_times_all < hi)
+                & (disease_times_all >= doa_days)
+            )
+            cand_pos = disease_pos_all[in_bin]
+            if cand_pos.size == 0:
+                continue
+
+            cutoff_pos = int(rng.choice(cand_pos))
+            t0_days = float(times_ins[cutoff_pos])
+
+            future_mask = (times_ins > (t0_days + eps)) & is_disease
+            future_pos = np.flatnonzero(future_mask)
+            if future_pos.size == 0:
+                next_cause = None
+                next_dt_years = None
+                future_causes = np.zeros((0,), dtype=np.int64)
+                future_dt_years_arr = np.zeros((0,), dtype=np.float32)
+            else:
+                future_times_days = times_ins[future_pos]
+                future_tokens = codes_ins[future_pos]
+                future_causes = (
+                    future_tokens - N_TECH_TOKENS).astype(np.int64)
+                future_dt_years_arr = (
+                    (future_times_days - t0_days) / DAYS_PER_YEAR
+                ).astype(np.float32)
+
+                next_idx = int(np.argmin(future_times_days))
+                next_cause = int(future_causes[next_idx])
+                next_dt_years = float(future_dt_years_arr[next_idx])
+
+            records.append(
+                EvalRecord(
+                    subset_idx=int(subset_idx),
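The dispatch above keeps a full sequential copy of the loop as the fallback, so behavior and progress reporting are unchanged when joblib is missing or n_jobs == 1. A self-contained sketch of the same chunk-then-flatten shape, assuming joblib is installed (toy workload in place of the record builder; with prefer="threads", NumPy-light Python code may still be GIL-bound):

    import numpy as np
    from joblib import Parallel, delayed

    def process_chunk(chunk, *, seed_local):
        out = []
        for i in chunk:
            rng = np.random.default_rng(np.random.SeedSequence([seed_local, int(i)]))
            out.append(float(rng.random()))
        return out

    idx = np.arange(10)
    chunks = [idx[i:i + 4] for i in range(0, 10, 4)]
    parts = Parallel(n_jobs=2, prefer="threads", batch_size=1)(
        delayed(process_chunk)(c, seed_local=42) for c in chunks
    )
    flat = [r for part in parts for r in part]
    # Per-index seeding makes the parallel output equal the sequential result.
    assert flat == process_chunk(idx, seed_local=42)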