Add --max_cpu_cores argument for parallel processing in evaluation scripts
@@ -77,6 +77,12 @@ def parse_args() -> argparse.Namespace:
     )
     p.add_argument("--batch_size", type=int, default=256)
     p.add_argument("--num_workers", type=int, default=0)
+    p.add_argument(
+        "--max_cpu_cores",
+        type=int,
+        default=-1,
+        help="Maximum number of CPU cores to use for parallel data construction.",
+    )
     p.add_argument("--seed", type=int, default=0)
     p.add_argument("--min_pos", type=int, default=20)
     p.add_argument(
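
The flag follows joblib's n_jobs convention, where -1 means "use all available cores". A minimal standalone sketch of just this argument (the parse_args(["--max_cpu_cores", "4"]) call is illustrative, not from the commit):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument(
        "--max_cpu_cores",
        type=int,
        default=-1,  # passed through to joblib, which treats -1 as "all cores"
        help="Maximum number of CPU cores to use for parallel data construction.",
    )
    args = p.parse_args(["--max_cpu_cores", "4"])
    print(args.max_cpu_cores)  # 4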
@@ -171,6 +177,7 @@ def main() -> None:
         age_bins_years=age_bins_years,
         seed=args.seed,
         show_progress=show_progress,
+        n_jobs=int(args.max_cpu_cores),
     )
 
     device = torch.device(args.device)
@@ -50,6 +50,12 @@ def parse_args() -> argparse.Namespace:
     )
     p.add_argument("--batch_size", type=int, default=256)
     p.add_argument("--num_workers", type=int, default=0)
+    p.add_argument(
+        "--max_cpu_cores",
+        type=int,
+        default=-1,
+        help="Maximum number of CPU cores to use for parallel data construction.",
+    )
     p.add_argument("--seed", type=int, default=0)
     p.add_argument(
         "--min_pos",
@@ -96,6 +102,45 @@ def _compute_next_event_auc_clean_control(
         dtype=np.int64,
     )
 
+    # Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
+    # lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes.
+    # Use a sparse matrix when SciPy is available to keep memory bounded for large K.
+    row_parts: List[np.ndarray] = []
+    col_parts: List[np.ndarray] = []
+    for i, r in enumerate(records):
+        causes = getattr(r, "lifetime_causes", None)
+        if causes is None:
+            continue
+        causes = np.asarray(causes, dtype=np.int64)
+        if causes.size == 0:
+            continue
+        # Keep only valid cause ids.
+        m_valid = (causes >= 0) & (causes < K)
+        if not np.any(m_valid):
+            continue
+        causes = causes[m_valid]
+        row_parts.append(np.full((causes.size,), i, dtype=np.int32))
+        col_parts.append(causes.astype(np.int32, copy=False))
+
+    try:
+        import scipy.sparse as sp  # type: ignore
+
+        if row_parts:
+            rows = np.concatenate(row_parts, axis=0)
+            cols = np.concatenate(col_parts, axis=0)
+            data = np.ones((rows.size,), dtype=bool)
+            lifetime_matrix = sp.csc_matrix(
+                (data, (rows, cols)), shape=(n_records, K))
+        else:
+            lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
+        lifetime_is_sparse = True
+    except Exception:  # pragma: no cover
+        lifetime_matrix = np.zeros((n_records, K), dtype=bool)
+        for rows, cols in zip(row_parts, col_parts):
+            lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(
+                np.int64, copy=False)] = True
+        lifetime_is_sparse = False
+
     auc = np.full((K,), np.nan, dtype=np.float64)
     var = np.full((K,), np.nan, dtype=np.float64)
     n_case = np.zeros((K,), dtype=np.int64)
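
A toy illustration of the membership-matrix idea, assuming SciPy is installed: rows are records, columns are cause ids, and entry (i, c) is True iff record i ever had cause c. The CSC layout makes the per-cause column reads in the loop below cheap:

    import numpy as np
    import scipy.sparse as sp

    # Record 0 had causes 1 and 3; record 1 had none; record 2 had cause 1.
    rows = np.array([0, 0, 2], dtype=np.int32)
    cols = np.array([1, 3, 1], dtype=np.int32)
    data = np.ones((rows.size,), dtype=bool)
    m = sp.csc_matrix((data, (rows, cols)), shape=(3, 4))

    had_cause_1 = np.asarray(m.getcol(1).toarray().ravel(), dtype=bool)
    print(had_cause_1)  # [ True False  True]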
@@ -109,12 +154,13 @@ def _compute_next_event_auc_clean_control(
         # Clean controls: not next-event k AND never had k in their lifetime history.
         control_mask = y_next != k
         if np.any(control_mask):
-            clean = np.fromiter(
-                ((k not in rec.lifetime_causes) for rec in records),
-                dtype=bool,
-                count=n_records,
-            )
-            control_mask = control_mask & clean
+            if lifetime_is_sparse:
+                had_k = np.asarray(lifetime_matrix.getcol(
+                    k).toarray().ravel(), dtype=bool)
+            else:
+                had_k = lifetime_matrix[:, k]
+            is_clean = ~had_k
+            control_mask = control_mask & is_clean
 
         cs = scores[case_mask, k]
         hs = scores[control_mask, k]
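
A toy equivalence check of this rewrite, using hypothetical lifetime sets: the old per-record generator scan and the new column read produce the same clean-control mask, but the matrix version avoids re-scanning every record for every cause k:

    import numpy as np

    lifetime = [np.array([1, 3]), np.array([], dtype=np.int64), np.array([1])]
    n, K, k = len(lifetime), 4, 1

    matrix = np.zeros((n, K), dtype=bool)  # dense stand-in for lifetime_matrix
    for i, causes in enumerate(lifetime):
        matrix[i, causes] = True

    old_clean = np.fromiter(((k not in c) for c in lifetime), dtype=bool, count=n)
    new_clean = ~matrix[:, k]
    assert np.array_equal(old_clean, new_clean)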
@@ -156,6 +202,7 @@ def main() -> None:
         age_bins_years=age_bins_years,
         seed=args.seed,
         show_progress=show_progress,
+        n_jobs=int(args.max_cpu_cores),
     )
 
     device = torch.device(args.device)
@@ -31,7 +31,7 @@ Options:
 
 Common eval args:
   Anything after `--` is appended to BOTH evaluation commands.
-  Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --seed, --min_pos, --no_tqdm).
+  Use this only for flags supported by BOTH scripts (e.g. --batch_size, --num_workers, --max_cpu_cores, --seed, --min_pos, --no_tqdm).
 
 Per-eval args:
   For eval-specific flags (e.g. evaluate_horizon.py --topk_list / --workload_fracs), use --horizon-args-file.
@@ -41,6 +41,8 @@ Examples:
   ./run_evaluations_multi_gpu.sh --gpus 0,1
   ./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \
     --horizons 0.25 0.5 1 2 5 10 --age-bins 40 45 50 55 60 65 70 75 inf -- --batch_size 512 --num_workers 4
+  ./run_evaluations_multi_gpu.sh --gpus 0,1 --runs-root runs --pattern "delphi_*" \
+    -- --batch_size 512 --num_workers 4 --max_cpu_cores -1
 
 USAGE
 }
utils.py (191 changed lines)
@@ -15,6 +15,12 @@ try:
 except Exception:  # pragma: no cover
     _tqdm = None
 
+try:
+    from joblib import Parallel, delayed  # type: ignore
+except Exception:  # pragma: no cover
+    Parallel = None
+    delayed = None
+
 from dataset import HealthDataset
 from losses import (
     DiscreteTimeCIFNLLLoss,
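
This mirrors the existing _tqdm guard: joblib stays an optional dependency, and callers test `Parallel is None` to fall back to a sequential path. A standalone sketch of the pattern (the square_all helper is hypothetical):

    try:
        from joblib import Parallel, delayed  # type: ignore
    except Exception:
        Parallel = None
        delayed = None

    def square_all(xs, n_jobs=1):
        if n_jobs != 1 and Parallel is not None and delayed is not None:
            return Parallel(n_jobs=n_jobs, prefer="threads")(
                delayed(lambda x: x * x)(x) for x in xs
            )
        return [x * x for x in xs]  # sequential fallback when joblib is absent

    print(square_all([1, 2, 3], n_jobs=2))  # [1, 4, 9]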
@@ -290,6 +296,9 @@ def build_event_driven_records(
     age_bins_years: Sequence[float],
     seed: int,
     show_progress: bool = False,
+    n_jobs: int = 1,
+    chunk_size: int = 256,
+    prefer: str = "threads",
 ) -> List[EvalRecord]:
     if len(age_bins_years) < 2:
         raise ValueError("age_bins must have at least 2 boundaries")
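
On the prefer="threads" default: thread workers share the parent process, so the subset object never needs to be pickled; prefer="processes" would sidestep the GIL but requires pickleable inputs (see the Windows note further down). A small sketch of the joblib call shape, assuming joblib is installed (the work function is hypothetical):

    from joblib import Parallel, delayed

    def work(i: int) -> int:
        return i * i

    # Thread workers: shared memory, no pickling of inputs or the callable.
    print(Parallel(n_jobs=2, prefer="threads")(delayed(work)(i) for i in range(4)))
    # [0, 1, 4, 9]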
@@ -298,20 +307,20 @@ def build_event_driven_records(
     if any(age_bins_days[i] >= age_bins_days[i + 1] for i in range(len(age_bins_days) - 1)):
         raise ValueError("age_bins must be strictly increasing")
 
-    rng = np.random.default_rng(seed)
-
-    records: List[EvalRecord] = []
-
-    # Build records exclusively from the provided subset.
-    # We intentionally avoid reading from subset.dataset internals so the
-    # evaluation pipeline does not depend on the full dataset object.
-    eps = 1e-6
-    for subset_idx in _progress(
-        range(len(subset)),
-        enabled=show_progress,
-        desc="Building eval records",
-        total=len(subset),
-    ):
+    def _iter_chunks(n: int, size: int) -> List[np.ndarray]:
+        if size <= 0:
+            raise ValueError("chunk_size must be >= 1")
+        if n == 0:
+            return []
+        idx = np.arange(n, dtype=np.int64)
+        return [idx[i:i + size] for i in range(0, n, size)]
+
+    def _build_records_for_index(
+        subset_idx: int,
+        *,
+        age_bins_days_local: Sequence[float],
+        rng_local: np.random.Generator,
+    ) -> List[EvalRecord]:
         event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
         codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
         times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
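
What _iter_chunks produces, as a quick toy run: contiguous index blocks of size elements, with a short tail chunk when n is not a multiple of size:

    import numpy as np

    def iter_chunks(n: int, size: int):
        idx = np.arange(n, dtype=np.int64)
        return [idx[i:i + size] for i in range(0, n, size)]

    print(iter_chunks(7, 3))
    # [array([0, 1, 2]), array([3, 4, 5]), array([6])]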
@@ -333,12 +342,17 @@ def build_event_driven_records(
             lifetime_causes = np.zeros((0,), dtype=np.int64)
 
         disease_pos_all = np.flatnonzero(is_disease)
-        disease_times_all = times_ins[disease_pos_all] if disease_pos_all.size > 0 else np.zeros(
-            (0,), dtype=np.float64)
+        disease_times_all = (
+            times_ins[disease_pos_all]
+            if disease_pos_all.size > 0
+            else np.zeros((0,), dtype=np.float64)
+        )
 
-        for b in range(len(age_bins_days) - 1):
-            lo = age_bins_days[b]
-            hi = age_bins_days[b + 1]
+        eps = 1e-6
+        out: List[EvalRecord] = []
+        for b in range(len(age_bins_days_local) - 1):
+            lo = float(age_bins_days_local[b])
+            hi = float(age_bins_days_local[b + 1])
 
             # Inclusion rule:
             # 1) DOA <= bin_upper
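
A toy check of the inclusion rule, with made-up event times in days: candidate cutoff events must fall in [lo, hi) and on or after the date of assessment (DOA):

    import numpy as np

    disease_times = np.array([100.0, 250.0, 400.0])
    doa_days, lo, hi = 200.0, 0.0, 300.0

    in_bin = (disease_times >= lo) & (disease_times < hi) & (disease_times >= doa_days)
    print(disease_times[in_bin])  # [250.]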
@@ -359,7 +373,7 @@ def build_event_driven_records(
             if cand_pos.size == 0:
                 continue
 
-            cutoff_pos = int(rng.choice(cand_pos))
+            cutoff_pos = int(rng_local.choice(cand_pos))
             t0_days = float(times_ins[cutoff_pos])
 
             # Future disease events strictly after t0
@@ -376,13 +390,150 @@ def build_event_driven_records(
                 future_causes = (
                     future_tokens - N_TECH_TOKENS).astype(np.int64)
                 future_dt_years_arr = (
-                    (future_times_days - t0_days) / DAYS_PER_YEAR).astype(np.float32)
+                    (future_times_days - t0_days) / DAYS_PER_YEAR
+                ).astype(np.float32)
 
                 # next-event = minimal time > t0 (tie broken by earliest position)
                 next_idx = int(np.argmin(future_times_days))
                 next_cause = int(future_causes[next_idx])
                 next_dt_years = float(future_dt_years_arr[next_idx])
 
+            out.append(
+                EvalRecord(
+                    subset_idx=int(subset_idx),
+                    doa_days=float(doa_days),
+                    t0_days=float(t0_days),
+                    cutoff_pos=int(cutoff_pos),
+                    next_event_cause=next_cause,
+                    next_event_dt_years=next_dt_years,
+                    lifetime_causes=lifetime_causes,
+                    future_causes=future_causes,
+                    future_dt_years=future_dt_years_arr,
+                )
+            )
+        return out
+
+    def _process_chunk(
+        chunk_indices: Sequence[int],
+        *,
+        age_bins_days_local: Sequence[float],
+        seed_local: int,
+    ) -> List[EvalRecord]:
+        out: List[EvalRecord] = []
+        for subset_idx in chunk_indices:
+            # Ensure each subject has its own deterministic RNG stream, so parallel
+            # workers do not share identical seeds.
+            ss = np.random.SeedSequence([int(seed_local), int(subset_idx)])
+            rng_local = np.random.default_rng(ss)
+            out.extend(
+                _build_records_for_index(
+                    int(subset_idx),
+                    age_bins_days_local=age_bins_days_local,
+                    rng_local=rng_local,
+                )
+            )
+        return out
+
+    n = int(len(subset))
+    chunks = _iter_chunks(n, int(chunk_size))
+
+    do_parallel = (
+        int(n_jobs) != 1
+        and Parallel is not None
+        and delayed is not None
+        and n > 0
+    )
+
+    if do_parallel:
+        # Note: on Windows, process-based parallelism may require the underlying
+        # dataset to be pickleable. `prefer="threads"` is the default for safety.
+        parts = Parallel(n_jobs=int(n_jobs), prefer=str(prefer), batch_size=1)(
+            delayed(_process_chunk)(
+                chunk,
+                age_bins_days_local=age_bins_days,
+                seed_local=int(seed),
+            )
+            for chunk in chunks
+        )
+        records = [r for part in parts for r in part]
+        return records
+
+    # Sequential (preserve prior behavior/progress reporting)
+    rng = np.random.default_rng(seed)
+    records: List[EvalRecord] = []
+    eps = 1e-6
+    for subset_idx in _progress(
+        range(len(subset)),
+        enabled=show_progress,
+        desc="Building eval records",
+        total=len(subset),
+    ):
+        event_tensor, time_tensor, _, _, _ = subset[int(subset_idx)]
+        codes_ins = event_tensor.detach().cpu().numpy().astype(np.int64, copy=False)
+        times_ins = time_tensor.detach().cpu().numpy().astype(np.float64, copy=False)
+
+        doa_pos = np.flatnonzero(codes_ins == 1)
+        if doa_pos.size == 0:
+            raise ValueError("Expected DOA token (code=1) in event sequence")
+        doa_days = float(times_ins[int(doa_pos[0])])
+
+        is_disease = codes_ins >= N_TECH_TOKENS
+
+        if np.any(is_disease):
+            lifetime_causes = (codes_ins[is_disease] - N_TECH_TOKENS).astype(
+                np.int64, copy=False
+            )
+            lifetime_causes = np.unique(lifetime_causes)
+        else:
+            lifetime_causes = np.zeros((0,), dtype=np.int64)
+
+        disease_pos_all = np.flatnonzero(is_disease)
+        disease_times_all = (
+            times_ins[disease_pos_all]
+            if disease_pos_all.size > 0
+            else np.zeros((0,), dtype=np.float64)
+        )
+
+        for b in range(len(age_bins_days) - 1):
+            lo = age_bins_days[b]
+            hi = age_bins_days[b + 1]
+            if not (doa_days <= hi):
+                continue
+            if disease_pos_all.size == 0:
+                continue
+
+            in_bin = (
+                (disease_times_all >= lo)
+                & (disease_times_all < hi)
+                & (disease_times_all >= doa_days)
+            )
+            cand_pos = disease_pos_all[in_bin]
+            if cand_pos.size == 0:
+                continue
+
+            cutoff_pos = int(rng.choice(cand_pos))
+            t0_days = float(times_ins[cutoff_pos])
+
+            future_mask = (times_ins > (t0_days + eps)) & is_disease
+            future_pos = np.flatnonzero(future_mask)
+            if future_pos.size == 0:
+                next_cause = None
+                next_dt_years = None
+                future_causes = np.zeros((0,), dtype=np.int64)
+                future_dt_years_arr = np.zeros((0,), dtype=np.float32)
+            else:
+                future_times_days = times_ins[future_pos]
+                future_tokens = codes_ins[future_pos]
+                future_causes = (
+                    future_tokens - N_TECH_TOKENS).astype(np.int64)
+                future_dt_years_arr = (
+                    (future_times_days - t0_days) / DAYS_PER_YEAR
+                ).astype(np.float32)
+
+                next_idx = int(np.argmin(future_times_days))
+                next_cause = int(future_causes[next_idx])
+                next_dt_years = float(future_dt_years_arr[next_idx])
+
             records.append(
                 EvalRecord(
                     subset_idx=int(subset_idx),
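
On the per-subject seeding in _process_chunk: deriving each generator from SeedSequence([seed, subset_idx]) makes every subject's draws deterministic and independent of chunk size and worker scheduling, so parallel runs are reproducible for a given seed. A quick standalone check:

    import numpy as np

    def draw(seed: int, idx: int) -> int:
        rng = np.random.default_rng(np.random.SeedSequence([seed, idx]))
        return int(rng.integers(0, 100))

    # Same (seed, idx) -> same draw, regardless of which worker or chunk handles it.
    assert draw(0, 5) == draw(0, 5)
    print([draw(0, i) for i in range(3)])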