Add --max_cpu_cores argument for parallel processing in evaluation scripts
This commit is contained in:
@@ -77,6 +77,12 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=256)
|
||||
p.add_argument("--num_workers", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--max_cpu_cores",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Maximum number of CPU cores to use for parallel data construction.",
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=0)
|
||||
p.add_argument("--min_pos", type=int, default=20)
|
||||
p.add_argument(
|
||||
@@ -171,6 +177,7 @@ def main() -> None:
|
||||
age_bins_years=age_bins_years,
|
||||
seed=args.seed,
|
||||
show_progress=show_progress,
|
||||
n_jobs=int(args.max_cpu_cores),
|
||||
)
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
@@ -50,6 +50,12 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=256)
|
||||
p.add_argument("--num_workers", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--max_cpu_cores",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Maximum number of CPU cores to use for parallel data construction.",
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--min_pos",
|
||||
@@ -96,6 +102,45 @@ def _compute_next_event_auc_clean_control(
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
# Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
|
||||
# lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes.
|
||||
# Use a sparse matrix when SciPy is available to keep memory bounded for large K.
|
||||
row_parts: List[np.ndarray] = []
|
||||
col_parts: List[np.ndarray] = []
|
||||
for i, r in enumerate(records):
|
||||
causes = getattr(r, "lifetime_causes", None)
|
||||
if causes is None:
|
||||
continue
|
||||
causes = np.asarray(causes, dtype=np.int64)
|
||||
if causes.size == 0:
|
||||
continue
|
||||
# Keep only valid cause ids.
|
||||
m_valid = (causes >= 0) & (causes < K)
|
||||
if not np.any(m_valid):
|
||||
continue
|
||||
causes = causes[m_valid]
|
||||
row_parts.append(np.full((causes.size,), i, dtype=np.int32))
|
||||
col_parts.append(causes.astype(np.int32, copy=False))
|
||||
|
||||
try:
|
||||
import scipy.sparse as sp # type: ignore
|
||||
|
||||
if row_parts:
|
||||
rows = np.concatenate(row_parts, axis=0)
|
||||
cols = np.concatenate(col_parts, axis=0)
|
||||
data = np.ones((rows.size,), dtype=bool)
|
||||
lifetime_matrix = sp.csc_matrix(
|
||||
(data, (rows, cols)), shape=(n_records, K))
|
||||
else:
|
||||
lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
|
||||
lifetime_is_sparse = True
|
||||
except Exception: # pragma: no cover
|
||||
lifetime_matrix = np.zeros((n_records, K), dtype=bool)
|
||||
for rows, cols in zip(row_parts, col_parts):
|
||||
lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(
|
||||
np.int64, copy=False)] = True
|
||||
lifetime_is_sparse = False
|
||||
|
||||
auc = np.full((K,), np.nan, dtype=np.float64)
|
||||
var = np.full((K,), np.nan, dtype=np.float64)
|
||||
n_case = np.zeros((K,), dtype=np.int64)
|
||||
@@ -109,12 +154,13 @@ def _compute_next_event_auc_clean_control(
|
||||
# Clean controls: not next-event k AND never had k in their lifetime history.
|
||||
control_mask = y_next != k
|
||||
if np.any(control_mask):
|
||||
clean = np.fromiter(
|
||||
((k not in rec.lifetime_causes) for rec in records),
|
||||
dtype=bool,
|
||||
count=n_records,
|
||||
)
|
||||
control_mask = control_mask & clean
|
||||
if lifetime_is_sparse:
|
||||
had_k = np.asarray(lifetime_matrix.getcol(
|
||||
k).toarray().ravel(), dtype=bool)
|
||||
else:
|
||||
had_k = lifetime_matrix[:, k]
|
||||
is_clean = ~had_k
|
||||
control_mask = control_mask & is_clean
|
||||
|
||||
cs = scores[case_mask, k]
|
||||
hs = scores[control_mask, k]
|
||||
@@ -156,6 +202,7 @@ def main() -> None:
|
||||
age_bins_years=age_bins_years,
|
||||
seed=args.seed,
|
||||
show_progress=show_progress,
|
||||
n_jobs=int(args.max_cpu_cores),
|
||||
)
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
@@ -31,7 +31,7 @@ Options:
|
||||
|
||||
Common eval args:
|
||||
Anything after `--` is appended to BOTH evaluation commands.
|
||||
Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --seed, --min_pos, --no_tqdm).
|
||||
Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --max_cpu_cores, --seed, --min_pos, --no_tqdm).
|
||||
|
||||
Per-eval args:
|
||||
For eval-specific flags (e.g. evaluate_horizon.py --topk_list / --workload_fracs), use --horizon-args-file.
|
||||
@@ -41,6 +41,8 @@ Examples:
|
||||
./run_evaluations_multi_gpu.sh --gpus 0,1
|
||||
./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \
|
||||
--horizons 0.25 0.5 1 2 5 10 --age-bins 40 45 50 55 60 65 70 75 inf -- --batch_size 512 --num_workers 4
|
||||
./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \
|
||||
-- --batch_size 512 --num_workers 4 --max_cpu_cores -1
|
||||
|
||||
USAGE
|
||||
}
|
||||
|
||||
191
utils.py
191
utils.py
@@ -15,6 +15,12 @@ try:
|
||||
except Exception: # pragma: no cover
|
||||
_tqdm = None
|
||||
|
||||
try:
|
||||
from joblib import Parallel, delayed # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
Parallel = None
|
||||
delayed = None
|
||||
|
||||
from dataset import HealthDataset
|
||||
from losses import (
|
||||
DiscreteTimeCIFNLLLoss,
|
||||
@@ -290,6 +296,9 @@ def build_event_driven_records(
|
||||
age_bins_years: Sequence[float],
|
||||
seed: int,
|
||||
show_progress: bool = False,
|
||||
n_jobs: int = 1,
|
||||
chunk_size: int = 256,
|
||||
prefer: str = "threads",
|
||||
) -> List[EvalRecord]:
|
||||
if len(age_bins_years) < 2:
|
||||
raise ValueError("age_bins must have at least 2 boundaries")
|
||||
@@ -298,20 +307,20 @@ def build_event_driven_records(
|
||||
if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
|
||||
raise ValueError("age_bins must be strictly increasing")
|
||||
|
||||
rng = np.random.default_rng(seed)
|
||||
def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
|
||||
if size <= 0:
|
||||
raise ValueError("chunk_size must be >= 1")
|
||||
if n == 0:
|
||||
return []
|
||||
idx = np.arange(n, dtype=np.int64)
|
||||
return [idx[i:i + size] for i in range(0, n, size)]
|
||||
|
||||
records: List[EvalRecord] = []
|
||||
|
||||
# Build records exclusively from the provided subset.
|
||||
# We intentionally avoid reading from subset.dataset internals so the
|
||||
# evaluation pipeline does not depend on the full dataset object.
|
||||
eps = 1e-6
|
||||
for subset_idx in _progress(
|
||||
range(len(subset)),
|
||||
enabled=show_progress,
|
||||
desc="Building eval records",
|
||||
total=len(subset),
|
||||
):
|
||||
def _build_records_for_index(
|
||||
subset_idx: int,
|
||||
*,
|
||||
age_bins_days_local: Sequence[float],
|
||||
rng_local: np.random.Generator,
|
||||
) -> List[EvalRecord]:
|
||||
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
|
||||
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
|
||||
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
|
||||
@@ -333,12 +342,17 @@ def build_event_driven_records(
|
||||
lifetime_causes = np.zeros((0,), dtype=np.int64)
|
||||
|
||||
disease_pos_all = np.flatnonzero(is_disease)
|
||||
disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
|
||||
(0,), dtype=np.float64)
|
||||
disease_times_all = (
|
||||
times_ins[disease_pos_all]
|
||||
if disease_pos_all.size > 0
|
||||
else np.zeros((0,), dtype=np.float64)
|
||||
)
|
||||
|
||||
for b in range(len(age_bins_days) - 1):
|
||||
lo = age_bins_days[b]
|
||||
hi = age_bins_days[b + 1]
|
||||
eps = 1e-6
|
||||
out: List[EvalRecord] = []
|
||||
for b in range(len(age_bins_days_local) - 1):
|
||||
lo = float(age_bins_days_local[b])
|
||||
hi = float(age_bins_days_local[b + 1])
|
||||
|
||||
# Inclusion rule:
|
||||
# 1) DOA <= bin_upper
|
||||
@@ -359,7 +373,7 @@ def build_event_driven_records(
|
||||
if cand_pos.size == 0:
|
||||
continue
|
||||
|
||||
cutoff_pos = int(rng.choice(cand_pos))
|
||||
cutoff_pos = int(rng_local.choice(cand_pos))
|
||||
t0_days = float(times_ins[cutoff_pos])
|
||||
|
||||
# Future disease events strictly after t0
|
||||
@@ -376,13 +390,150 @@ def build_event_driven_records(
|
||||
future_causes = (
|
||||
future_tokens - N_TECH_TOKENS).astype(np.int64)
|
||||
future_dt_years_arr = (
|
||||
(future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
|
||||
(future_times_days - t0_days) / DAYS_PER_YEAR
|
||||
).astype(np.float32)
|
||||
|
||||
# next-event = minimal time > t0 (tie broken by earliest position)
|
||||
next_idx = int(np.argmin(future_times_days))
|
||||
next_cause = int(future_causes[next_idx])
|
||||
next_dt_years = float(future_dt_years_arr[next_idx])
|
||||
|
||||
out.append(
|
||||
EvalRecord(
|
||||
subset_idx=int(subset_idx),
|
||||
doa_days=float(doa_days),
|
||||
t0_days=float(t0_days),
|
||||
cutoff_pos=int(cutoff_pos),
|
||||
next_event_cause=next_cause,
|
||||
next_event_dt_years=next_dt_years,
|
||||
lifetime_causes=lifetime_causes,
|
||||
future_causes=future_causes,
|
||||
future_dt_years=future_dt_years_arr,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
def _process_chunk(
|
||||
chunk_indices: Sequence[int],
|
||||
*,
|
||||
age_bins_days_local: Sequence[float],
|
||||
seed_local: int,
|
||||
) -> List[EvalRecord]:
|
||||
out: List[EvalRecord] = []
|
||||
for subset_idx in chunk_indices:
|
||||
# Ensure each subject has its own deterministic RNG stream, so parallel
|
||||
# workers do not share identical seeds.
|
||||
ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
|
||||
rng_local = np.random.default_rng(ss)
|
||||
out.extend(
|
||||
_build_records_for_index(
|
||||
int(subset_idx),
|
||||
age_bins_days_local=age_bins_days_local,
|
||||
rng_local=rng_local,
|
||||
)
|
||||
)
|
||||
return out
|
||||
|
||||
n = int(len(subset))
|
||||
chunks = _iter_chunks(n, int(chunk_size))
|
||||
|
||||
do_parallel = (
|
||||
int(n_jobs) != 1
|
||||
and Parallel is not None
|
||||
and delayed is not None
|
||||
and n > 0
|
||||
)
|
||||
|
||||
if do_parallel:
|
||||
# Note: on Windows, process-based parallelism may require the underlying
|
||||
# dataset to be pickleable. `prefer="threads"` is the default for safety.
|
||||
parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
|
||||
delayed(_process_chunk)(
|
||||
chunk,
|
||||
age_bins_days_local=age_bins_days,
|
||||
seed_local=int(seed),
|
||||
)
|
||||
for chunk in chunks
|
||||
)
|
||||
records = [r for part in parts for r in part]
|
||||
return records
|
||||
|
||||
# Sequential (preserve prior behavior/progress reporting)
|
||||
rng = np.random.default_rng(seed)
|
||||
records: List[EvalRecord] = []
|
||||
eps = 1e-6
|
||||
for subset_idx in _progress(
|
||||
range(len(subset)),
|
||||
enabled=show_progress,
|
||||
desc="Building eval records",
|
||||
total=len(subset),
|
||||
):
|
||||
event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
|
||||
codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
|
||||
times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
|
||||
|
||||
doa_pos = np.flatnonzero(codes_ins == 1)
|
||||
if doa_pos.size == 0:
|
||||
raise ValueError("Expected DOA token (code=1) in event sequence")
|
||||
doa_days = float(times_ins[int(doa_pos[0])])
|
||||
|
||||
is_disease = codes_ins >= N_TECH_TOKENS
|
||||
|
||||
if np.any(is_disease):
|
||||
lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
|
||||
np.int64, copy=False
|
||||
)
|
||||
lifetime_causes = np.unique(lifetime_causes)
|
||||
else:
|
||||
lifetime_causes = np.zeros((0,), dtype=np.int64)
|
||||
|
||||
disease_pos_all = np.flatnonzero(is_disease)
|
||||
disease_times_all = (
|
||||
times_ins[disease_pos_all]
|
||||
if disease_pos_all.size > 0
|
||||
else np.zeros((0,), dtype=np.float64)
|
||||
)
|
||||
|
||||
for b in range(len(age_bins_days) - 1):
|
||||
lo = age_bins_days[b]
|
||||
hi = age_bins_days[b + 1]
|
||||
if not (doa_days <= hi):
|
||||
continue
|
||||
if disease_pos_all.size == 0:
|
||||
continue
|
||||
|
||||
in_bin = (
|
||||
(disease_times_all >= lo)
|
||||
& (disease_times_all < hi)
|
||||
& (disease_times_all >= doa_days)
|
||||
)
|
||||
cand_pos = disease_pos_all[in_bin]
|
||||
if cand_pos.size == 0:
|
||||
continue
|
||||
|
||||
cutoff_pos = int(rng.choice(cand_pos))
|
||||
t0_days = float(times_ins[cutoff_pos])
|
||||
|
||||
future_mask = (times_ins > (t0_days + eps)) & is_disease
|
||||
future_pos = np.flatnonzero(future_mask)
|
||||
if future_pos.size == 0:
|
||||
next_cause = None
|
||||
next_dt_years = None
|
||||
future_causes = np.zeros((0,), dtype=np.int64)
|
||||
future_dt_years_arr = np.zeros((0,), dtype=np.float32)
|
||||
else:
|
||||
future_times_days = times_ins[future_pos]
|
||||
future_tokens = codes_ins[future_pos]
|
||||
future_causes = (
|
||||
future_tokens - N_TECH_TOKENS).astype(np.int64)
|
||||
future_dt_years_arr = (
|
||||
(future_times_days - t0_days) / DAYS_PER_YEAR
|
||||
).astype(np.float32)
|
||||
|
||||
next_idx = int(np.argmin(future_times_days))
|
||||
next_cause = int(future_causes[next_idx])
|
||||
next_dt_years = float(future_dt_years_arr[next_idx])
|
||||
|
||||
records.append(
|
||||
EvalRecord(
|
||||
subset_idx=int(subset_idx),
|
||||
|
||||
Reference in New Issue
Block a user