Add --max_cpu_cores argument for parallel processing in evaluation scripts

This commit is contained in:
2026-01-17 23:53:24 +08:00
parent 248fb09c34
commit cfe7f88162
4 changed files with 234 additions and 27 deletions

View File

@@ -77,6 +77,12 @@ def parse_args() -> argparse.Namespace:
)
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--num_workers", type=int, default=0)
p.add_argument(
"--max_cpu_cores",
type=int,
default=-1,
help="Maximum number of CPU cores to use for parallel data construction.",
)
p.add_argument("--seed", type=int, default=0)
p.add_argument("--min_pos", type=int, default=20)
p.add_argument(
@@ -171,6 +177,7 @@ def main() -> None:
age_bins_years=age_bins_years,
seed=args.seed,
show_progress=show_progress,
n_jobs=int(args.max_cpu_cores),
)
device = torch.device(args.device)

View File

@@ -50,6 +50,12 @@ def parse_args() -> argparse.Namespace:
)
p.add_argument("--batch_size", type=int, default=256)
p.add_argument("--num_workers", type=int, default=0)
p.add_argument(
"--max_cpu_cores",
type=int,
default=-1,
help="Maximum number of CPU cores to use for parallel data construction.",
)
p.add_argument("--seed", type=int, default=0)
p.add_argument(
"--min_pos",
@@ -96,6 +102,45 @@ def _compute_next_event_auc_clean_control(
dtype=np.int64,
)
# Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
# lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes.
# Use a sparse matrix when SciPy is available to keep memory bounded for large K.
row_parts: List[np.ndarray] = []
col_parts: List[np.ndarray] = []
for i, r in enumerate(records):
causes = getattr(r, "lifetime_causes", None)
if causes is None:
continue
causes = np.asarray(causes, dtype=np.int64)
if causes.size == 0:
continue
# Keep only valid cause ids.
m_valid = (causes >= 0) & (causes < K)
if not np.any(m_valid):
continue
causes = causes[m_valid]
row_parts.append(np.full((causes.size,), i, dtype=np.int32))
col_parts.append(causes.astype(np.int32, copy=False))
try:
import scipy.sparse as sp # type: ignore
if row_parts:
rows = np.concatenate(row_parts, axis=0)
cols = np.concatenate(col_parts, axis=0)
data = np.ones((rows.size,), dtype=bool)
lifetime_matrix = sp.csc_matrix(
(data, (rows, cols)), shape=(n_records, K))
else:
lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
lifetime_is_sparse = True
except Exception: # pragma: no cover
lifetime_matrix = np.zeros((n_records, K), dtype=bool)
for rows, cols in zip(row_parts, col_parts):
lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(
np.int64, copy=False)] = True
lifetime_is_sparse = False
auc = np.full((K,), np.nan, dtype=np.float64)
var = np.full((K,), np.nan, dtype=np.float64)
n_case = np.zeros((K,), dtype=np.int64)
@@ -109,12 +154,13 @@ def _compute_next_event_auc_clean_control(
# Clean controls: not next-event k AND never had k in their lifetime history.
control_mask = y_next != k
if np.any(control_mask):
clean = np.fromiter(
((k not in rec.lifetime_causes) for rec in records),
dtype=bool,
count=n_records,
)
control_mask = control_mask & clean
if lifetime_is_sparse:
had_k = np.asarray(lifetime_matrix.getcol(
k).toarray().ravel(), dtype=bool)
else:
had_k = lifetime_matrix[:, k]
is_clean = ~had_k
control_mask = control_mask & is_clean
cs = scores[case_mask, k]
hs = scores[control_mask, k]
@@ -156,6 +202,7 @@ def main() -> None:
age_bins_years=age_bins_years,
seed=args.seed,
show_progress=show_progress,
n_jobs=int(args.max_cpu_cores),
)
device = torch.device(args.device)

View File

@@ -31,7 +31,7 @@ Options:
Common eval args:
Anything after `--` is appended to BOTH evaluation commands.
Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --seed, --min_pos, --no_tqdm).
Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --max_cpu_cores, --seed, --min_pos, --no_tqdm).
Per-eval args:
For eval-specific flags (e.g. evaluate_horizon.py --topk_list / --workload_fracs), use --horizon-args-file.
@@ -41,6 +41,8 @@ Examples:
./run_evaluations_multi_gpu.sh --gpus 0,1
./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \
--horizons 0.25 0.5 1 2 5 10 --age-bins 40 45 50 55 60 65 70 75 inf -- --batch_size 512 --num_workers 4
./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \
-- --batch_size 512 --num_workers 4 --max_cpu_cores -1
USAGE
}

191
utils.py
View File

@@ -15,6 +15,12 @@ try:
except Exception: # pragma: no cover
_tqdm = None
try:
from joblib import Parallel, delayed # type: ignore
except Exception: # pragma: no cover
Parallel = None
delayed = None
from dataset import HealthDataset
from losses import (
DiscreteTimeCIFNLLLoss,
@@ -290,6 +296,9 @@ def build_event_driven_records(
age_bins_years: Sequence[float],
seed: int,
show_progress: bool = False,
n_jobs: int = 1,
chunk_size: int = 256,
prefer: str = "threads",
) -> List[EvalRecord]:
if len(age_bins_years) < 2:
raise ValueError("age_bins must have at least 2 boundaries")
@@ -298,20 +307,20 @@ def build_event_driven_records(
if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
raise ValueError("age_bins must be strictly increasing")
rng = np.random.default_rng(seed)
def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
if size <= 0:
raise ValueError("chunk_size must be >= 1")
if n == 0:
return []
idx = np.arange(n, dtype=np.int64)
return [idx[i:i + size] for i in range(0, n, size)]
records: List[EvalRecord] = []
# Build records exclusively from the provided subset.
# We intentionally avoid reading from subset.dataset internals so the
# evaluation pipeline does not depend on the full dataset object.
eps = 1e-6
for subset_idx in _progress(
range(len(subset)),
enabled=show_progress,
desc="Building eval records",
total=len(subset),
):
def _build_records_for_index(
subset_idx: int,
*,
age_bins_days_local: Sequence[float],
rng_local: np.random.Generator,
) -> List[EvalRecord]:
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
@@ -333,12 +342,17 @@ def build_event_driven_records(
lifetime_causes = np.zeros((0,), dtype=np.int64)
disease_pos_all = np.flatnonzero(is_disease)
disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
(0,), dtype=np.float64)
disease_times_all = (
times_ins[disease_pos_all]
if disease_pos_all.size > 0
else np.zeros((0,), dtype=np.float64)
)
for b in range(len(age_bins_days) - 1):
lo = age_bins_days[b]
hi = age_bins_days[b + 1]
eps = 1e-6
out: List[EvalRecord] = []
for b in range(len(age_bins_days_local) - 1):
lo = float(age_bins_days_local[b])
hi = float(age_bins_days_local[b + 1])
# Inclusion rule:
# 1) DOA <= bin_upper
@@ -359,7 +373,7 @@ def build_event_driven_records(
if cand_pos.size == 0:
continue
cutoff_pos = int(rng.choice(cand_pos))
cutoff_pos = int(rng_local.choice(cand_pos))
t0_days = float(times_ins[cutoff_pos])
# Future disease events strictly after t0
@@ -376,13 +390,150 @@ def build_event_driven_records(
future_causes = (
future_tokens - N_TECH_TOKENS).astype(np.int64)
future_dt_years_arr = (
(future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
(future_times_days - t0_days) / DAYS_PER_YEAR
).astype(np.float32)
# next-event = minimal time > t0 (tie broken by earliest position)
next_idx = int(np.argmin(future_times_days))
next_cause = int(future_causes[next_idx])
next_dt_years = float(future_dt_years_arr[next_idx])
out.append(
EvalRecord(
subset_idx=int(subset_idx),
doa_days=float(doa_days),
t0_days=float(t0_days),
cutoff_pos=int(cutoff_pos),
next_event_cause=next_cause,
next_event_dt_years=next_dt_years,
lifetime_causes=lifetime_causes,
future_causes=future_causes,
future_dt_years=future_dt_years_arr,
)
)
return out
def _process_chunk(
chunk_indices: Sequence[int],
*,
age_bins_days_local: Sequence[float],
seed_local: int,
) -> List[EvalRecord]:
out: List[EvalRecord] = []
for subset_idx in chunk_indices:
# Ensure each subject has its own deterministic RNG stream, so parallel
# workers do not share identical seeds.
ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
rng_local = np.random.default_rng(ss)
out.extend(
_build_records_for_index(
int(subset_idx),
age_bins_days_local=age_bins_days_local,
rng_local=rng_local,
)
)
return out
n = int(len(subset))
chunks = _iter_chunks(n, int(chunk_size))
do_parallel = (
int(n_jobs) != 1
and Parallel is not None
and delayed is not None
and n > 0
)
if do_parallel:
# Note: on Windows, process-based parallelism may require the underlying
# dataset to be pickleable. `prefer="threads"` is the default for safety.
parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
delayed(_process_chunk)(
chunk,
age_bins_days_local=age_bins_days,
seed_local=int(seed),
)
for chunk in chunks
)
records = [r for part in parts for r in part]
return records
# Sequential (preserve prior behavior/progress reporting)
rng = np.random.default_rng(seed)
records: List[EvalRecord] = []
eps = 1e-6
for subset_idx in _progress(
range(len(subset)),
enabled=show_progress,
desc="Building eval records",
total=len(subset),
):
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
doa_pos = np.flatnonzero(codes_ins == 1)
if doa_pos.size == 0:
raise ValueError("Expected DOA token (code=1) in event sequence")
doa_days = float(times_ins[int(doa_pos[0])])
is_disease = codes_ins >= N_TECH_TOKENS
if np.any(is_disease):
lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
np.int64, copy=False
)
lifetime_causes = np.unique(lifetime_causes)
else:
lifetime_causes = np.zeros((0,), dtype=np.int64)
disease_pos_all = np.flatnonzero(is_disease)
disease_times_all = (
times_ins[disease_pos_all]
if disease_pos_all.size > 0
else np.zeros((0,), dtype=np.float64)
)
for b in range(len(age_bins_days) - 1):
lo = age_bins_days[b]
hi = age_bins_days[b + 1]
if not (doa_days <= hi):
continue
if disease_pos_all.size == 0:
continue
in_bin = (
(disease_times_all >= lo)
& (disease_times_all < hi)
& (disease_times_all >= doa_days)
)
cand_pos = disease_pos_all[in_bin]
if cand_pos.size == 0:
continue
cutoff_pos = int(rng.choice(cand_pos))
t0_days = float(times_ins[cutoff_pos])
future_mask = (times_ins > (t0_days + eps)) & is_disease
future_pos = np.flatnonzero(future_mask)
if future_pos.size == 0:
next_cause = None
next_dt_years = None
future_causes = np.zeros((0,), dtype=np.int64)
future_dt_years_arr = np.zeros((0,), dtype=np.float32)
else:
future_times_days = times_ins[future_pos]
future_tokens = codes_ins[future_pos]
future_causes = (
future_tokens - N_TECH_TOKENS).astype(np.int64)
future_dt_years_arr = (
(future_times_days - t0_days) / DAYS_PER_YEAR
).astype(np.float32)
next_idx = int(np.argmin(future_times_days))
next_cause = int(future_causes[next_idx])
next_dt_years = float(future_dt_years_arr[next_idx])
records.append(
EvalRecord(
subset_idx=int(subset_idx),