Add --max_cpu_cores argument for parallel processing in evaluation scripts

2026-01-17 23:53:24 +08:00
parent 248fb09c34
commit cfe7f88162
4 changed files with 234 additions and 27 deletions
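The --max_cpu_cores flag itself is wired up in the evaluation scripts, which are among the four changed files but are not shown in this excerpt. A minimal sketch of how such a flag could feed the new n_jobs parameter of build_event_driven_records (the argparse wiring below is an assumption; only the flag name and the n_jobs/chunk_size/prefer parameters come from this commit):

    import argparse
    import os

    # Hypothetical CLI wiring; the actual evaluation scripts are not shown here.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--max_cpu_cores",
        type=int,
        default=1,
        help="CPU cores for building eval records; -1 uses all available cores.",
    )
    args = parser.parse_args()

    # joblib uses the same convention, so -1 could also be passed through as-is.
    n_jobs = args.max_cpu_cores if args.max_cpu_cores > 0 else (os.cpu_count() or 1)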

utils.py (191 lines changed)

@@ -15,6 +15,12 @@ try:
 except Exception:  # pragma: no cover
     _tqdm = None
+try:
+    from joblib import Parallel, delayed  # type: ignore
+except Exception:  # pragma: no cover
+    Parallel = None
+    delayed = None
 from dataset import HealthDataset
 from losses import (
     DiscreteTimeCIFNLLLoss,
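The guarded import keeps joblib optional: when it is unavailable, Parallel and delayed stay None, and the do_parallel check later in this diff routes everything through the sequential path. The same pattern in isolation (a sketch, not code from this commit):

    # Probe for the optional dependency once at import time.
    try:
        from joblib import Parallel, delayed  # type: ignore
    except Exception:  # pragma: no cover
        Parallel = None
        delayed = None

    def run_all(tasks, n_jobs: int = 1):
        # Use joblib only when it is installed and parallelism was requested;
        # otherwise degrade gracefully to a plain loop.
        if n_jobs != 1 and Parallel is not None and delayed is not None:
            return Parallel(n_jobs=n_jobs)(delayed(task)() for task in tasks)
        return [task() for task in tasks]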
@@ -290,6 +296,9 @@ def build_event_driven_records(
     age_bins_years: Sequence[float],
     seed: int,
     show_progress: bool = False,
+    n_jobs: int = 1,
+    chunk_size: int = 256,
+    prefer: str = "threads",
 ) -> List[EvalRecord]:
     if len(age_bins_years) < 2:
         raise ValueError("age_bins must have at least 2 boundaries")
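For reference, a call using the extended signature might look as follows (a sketch: the subset argument and the age-bin values are placeholders, and the parameter layout before age_bins_years is inferred from the function body rather than shown in this hunk):

    records = build_event_driven_records(
        subset,                  # indexable dataset split, as used in the body below
        age_bins_years=[40.0, 50.0, 60.0, 70.0, 80.0],
        seed=42,
        show_progress=False,     # progress reporting only applies to the sequential path
        n_jobs=-1,               # joblib convention: -1 = all cores
        chunk_size=256,          # subjects handled per parallel task
        prefer="threads",        # "processes" requires a pickleable dataset
    )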
@@ -298,20 +307,20 @@
     if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
         raise ValueError("age_bins must be strictly increasing")
-    rng = np.random.default_rng(seed)
+
+    def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
+        if size <= 0:
+            raise ValueError("chunk_size must be >= 1")
+        if n == 0:
+            return []
+        idx = np.arange(n, dtype=np.int64)
+        return [idx[i:i + size] for i in range(0, n, size)]
-    records: List[EvalRecord] = []
-    # Build records exclusively from the provided subset.
-    # We intentionally avoid reading from subset.dataset internals so the
-    # evaluation pipeline does not depend on the full dataset object.
-    eps = 1e-6
-    for subset_idx in _progress(
-        range(len(subset)),
-        enabled=show_progress,
-        desc="Building eval records",
-        total=len(subset),
-    ):
+
+    def _build_records_for_index(
+        subset_idx: int,
+        *,
+        age_bins_days_local: Sequence[float],
+        rng_local: np.random.Generator,
+    ) -> List[EvalRecord]:
         event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
         codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
         times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
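The new _iter_chunks helper simply partitions the index range [0, n) into contiguous int64 blocks of at most chunk_size elements; for example:

    import numpy as np

    # Same helper as in the hunk above.
    def _iter_chunks(n: int, size: int):
        if size <= 0:
            raise ValueError("chunk_size must be >= 1")
        if n == 0:
            return []
        idx = np.arange(n, dtype=np.int64)
        return [idx[i:i + size] for i in range(0, n, size)]

    print([c.tolist() for c in _iter_chunks(10, 4)])
    # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]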
@@ -333,12 +342,17 @@
             lifetime_causes = np.zeros((0,), dtype=np.int64)
         disease_pos_all = np.flatnonzero(is_disease)
-        disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
-            (0,), dtype=np.float64)
+        disease_times_all = (
+            times_ins[disease_pos_all]
+            if disease_pos_all.size > 0
+            else np.zeros((0,), dtype=np.float64)
+        )
-        for b in range(len(age_bins_days) - 1):
-            lo = age_bins_days[b]
-            hi = age_bins_days[b + 1]
+        eps = 1e-6
+        out: List[EvalRecord] = []
+        for b in range(len(age_bins_days_local) - 1):
+            lo = float(age_bins_days_local[b])
+            hi = float(age_bins_days_local[b + 1])
             # Inclusion rule:
             # 1) DOA <= bin_upper
@@ -359,7 +373,7 @@
             if cand_pos.size == 0:
                 continue
-            cutoff_pos = int(rng.choice(cand_pos))
+            cutoff_pos = int(rng_local.choice(cand_pos))
             t0_days = float(times_ins[cutoff_pos])
             # Future disease events strictly after t0
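This rename matters because the next hunk derives rng_local from SeedSequence([seed, subset_idx]): each subject gets its own RNG stream, so the sampled cutoff depends only on (seed, subset_idx), not on chunking or worker count. A quick self-contained check of that property (the function name here is illustrative):

    import numpy as np

    def pick_cutoff(seed: int, subset_idx: int, n_candidates: int) -> int:
        # Per-subject stream, mirroring the seeding scheme in the hunk below.
        ss = np.random.SeedSequence([int(seed), int(subset_idx)])
        rng = np.random.default_rng(ss)
        return int(rng.choice(np.arange(n_candidates)))

    # Identical regardless of how subjects are batched across workers.
    assert pick_cutoff(7, 123, 5) == pick_cutoff(7, 123, 5)

Note that the sequential fallback keeps a single shared rng, so runs with n_jobs=1 and n_jobs>1 can sample different cutoffs for the same seed.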
@@ -376,13 +390,150 @@
             future_causes = (
                 future_tokens - N_TECH_TOKENS).astype(np.int64)
-            future_dt_years_arr = (
-                (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
+            future_dt_years_arr = (
+                (future_times_days - t0_days) / DAYS_PER_YEAR
+            ).astype(np.float32)
             # next-event = minimal time > t0 (tie broken by earliest position)
             next_idx = int(np.argmin(future_times_days))
             next_cause = int(future_causes[next_idx])
             next_dt_years = float(future_dt_years_arr[next_idx])
+            out.append(
+                EvalRecord(
+                    subset_idx=int(subset_idx),
+                    doa_days=float(doa_days),
+                    t0_days=float(t0_days),
+                    cutoff_pos=int(cutoff_pos),
+                    next_event_cause=next_cause,
+                    next_event_dt_years=next_dt_years,
+                    lifetime_causes=lifetime_causes,
+                    future_causes=future_causes,
+                    future_dt_years=future_dt_years_arr,
+                )
+            )
+        return out
+
+    def _process_chunk(
+        chunk_indices: Sequence[int],
+        *,
+        age_bins_days_local: Sequence[float],
+        seed_local: int,
+    ) -> List[EvalRecord]:
+        out: List[EvalRecord] = []
+        for subset_idx in chunk_indices:
+            # Ensure each subject has its own deterministic RNG stream, so parallel
+            # workers do not share identical seeds.
+            ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
+            rng_local = np.random.default_rng(ss)
+            out.extend(
+                _build_records_for_index(
+                    int(subset_idx),
+                    age_bins_days_local=age_bins_days_local,
+                    rng_local=rng_local,
+                )
+            )
+        return out
+
+    n = int(len(subset))
+    chunks = _iter_chunks(n, int(chunk_size))
+    do_parallel = (
+        int(n_jobs) != 1
+        and Parallel is not None
+        and delayed is not None
+        and n > 0
+    )
+    if do_parallel:
+        # Note: on Windows, process-based parallelism may require the underlying
+        # dataset to be pickleable. `prefer="threads"` is the default for safety.
+        parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
+            delayed(_process_chunk)(
+                chunk,
+                age_bins_days_local=age_bins_days,
+                seed_local=int(seed),
+            )
+            for chunk in chunks
+        )
+        records = [r for part in parts for r in part]
+        return records
+
+    # Sequential (preserve prior behavior/progress reporting)
+    rng = np.random.default_rng(seed)
+    records: List[EvalRecord] = []
+    eps = 1e-6
+    for subset_idx in _progress(
+        range(len(subset)),
+        enabled=show_progress,
+        desc="Building eval records",
+        total=len(subset),
+    ):
+        event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
+        codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
+        times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
+        doa_pos = np.flatnonzero(codes_ins == 1)
+        if doa_pos.size == 0:
+            raise ValueError("Expected DOA token (code=1) in event sequence")
+        doa_days = float(times_ins[int(doa_pos[0])])
+        is_disease = codes_ins >= N_TECH_TOKENS
+        if np.any(is_disease):
+            lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
+                np.int64, copy=False
+            )
+            lifetime_causes = np.unique(lifetime_causes)
+        else:
+            lifetime_causes = np.zeros((0,), dtype=np.int64)
+        disease_pos_all = np.flatnonzero(is_disease)
+        disease_times_all = (
+            times_ins[disease_pos_all]
+            if disease_pos_all.size > 0
+            else np.zeros((0,), dtype=np.float64)
+        )
+        for b in range(len(age_bins_days) - 1):
+            lo = age_bins_days[b]
+            hi = age_bins_days[b + 1]
+            if not (doa_days <= hi):
+                continue
+            if disease_pos_all.size == 0:
+                continue
+            in_bin = (
+                (disease_times_all >= lo)
+                & (disease_times_all < hi)
+                & (disease_times_all >= doa_days)
+            )
+            cand_pos = disease_pos_all[in_bin]
+            if cand_pos.size == 0:
+                continue
+            cutoff_pos = int(rng.choice(cand_pos))
+            t0_days = float(times_ins[cutoff_pos])
+            future_mask = (times_ins > (t0_days + eps)) & is_disease
+            future_pos = np.flatnonzero(future_mask)
+            if future_pos.size == 0:
+                next_cause = None
+                next_dt_years = None
+                future_causes = np.zeros((0,), dtype=np.int64)
+                future_dt_years_arr = np.zeros((0,), dtype=np.float32)
+            else:
+                future_times_days = times_ins[future_pos]
+                future_tokens = codes_ins[future_pos]
+                future_causes = (
+                    future_tokens - N_TECH_TOKENS).astype(np.int64)
+                future_dt_years_arr = (
+                    (future_times_days - t0_days) / DAYS_PER_YEAR
+                ).astype(np.float32)
+                next_idx = int(np.argmin(future_times_days))
+                next_cause = int(future_causes[next_idx])
+                next_dt_years = float(future_dt_years_arr[next_idx])
+            records.append(
+                EvalRecord(
+                    subset_idx=int(subset_idx),