Add --max_cpu_cores argument for parallel processing in evaluation scripts
This commit is contained in:
@@ -50,6 +50,12 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
p.add_argument("--batch_size", type=int, default=256)
|
||||
p.add_argument("--num_workers", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--max_cpu_cores",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Maximum number of CPU cores to use for parallel data construction.",
|
||||
)
|
||||
p.add_argument("--seed", type=int, default=0)
|
||||
p.add_argument(
|
||||
"--min_pos",
|
||||
@@ -96,6 +102,45 @@ def _compute_next_event_auc_clean_control(
|
||||
dtype=np.int64,
|
||||
)
|
||||
|
||||
# Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
|
||||
# lifetime_matrix[i, c] is True iff cause c (with 0 <= c < K) is present in records[i].lifetime_causes; out-of-range ids are ignored.
|
||||
# Prefer a SciPy sparse matrix to keep memory bounded for large K; fall back to a dense boolean array if the SciPy path fails (e.g. SciPy is not installed).
|
||||
row_parts: List[np.ndarray] = []
|
||||
col_parts: List[np.ndarray] = []
|
||||
for i, r in enumerate(records):
|
||||
causes = getattr(r, "lifetime_causes", None)
|
||||
if causes is None:
|
||||
continue
|
||||
causes = np.asarray(causes, dtype=np.int64)
|
||||
if causes.size == 0:
|
||||
continue
|
||||
# Keep only valid cause ids.
|
||||
m_valid = (causes >= 0) & (causes < K)
|
||||
if not np.any(m_valid):
|
||||
continue
|
||||
causes = causes[m_valid]
|
||||
row_parts.append(np.full((causes.size,), i, dtype=np.int32))
|
||||
col_parts.append(causes.astype(np.int32, copy=False))
|
||||
|
||||
try:
|
||||
import scipy.sparse as sp # type: ignore
|
||||
|
||||
if row_parts:
|
||||
rows = np.concatenate(row_parts, axis=0)
|
||||
cols = np.concatenate(col_parts, axis=0)
|
||||
data = np.ones((rows.size,), dtype=bool)
|
||||
lifetime_matrix = sp.csc_matrix(
|
||||
(data, (rows, cols)), shape=(n_records, K))
|
||||
else:
|
||||
lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
|
||||
lifetime_is_sparse = True
|
||||
except Exception: # pragma: no cover
|
||||
lifetime_matrix = np.zeros((n_records, K), dtype=bool)
|
||||
for rows, cols in zip(row_parts, col_parts):
|
||||
lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(
|
||||
np.int64, copy=False)] = True
|
||||
lifetime_is_sparse = False
|
||||
|
||||
auc = np.full((K,), np.nan, dtype=np.float64)
|
||||
var = np.full((K,), np.nan, dtype=np.float64)
|
||||
n_case = np.zeros((K,), dtype=np.int64)
|
||||
@@ -109,12 +154,13 @@ def _compute_next_event_auc_clean_control(
|
||||
# Clean controls: not next-event k AND never had k in their lifetime history.
|
||||
control_mask = y_next != k
|
||||
if np.any(control_mask):
|
||||
clean = np.fromiter(
|
||||
((k not in rec.lifetime_causes) for rec in records),
|
||||
dtype=bool,
|
||||
count=n_records,
|
||||
)
|
||||
control_mask = control_mask & clean
|
||||
if lifetime_is_sparse:
|
||||
had_k = np.asarray(lifetime_matrix.getcol(
|
||||
k).toarray().ravel(), dtype=bool)
|
||||
else:
|
||||
had_k = lifetime_matrix[:, k]
|
||||
is_clean = ~had_k
|
||||
control_mask = control_mask & is_clean
|
||||
|
||||
cs = scores[case_mask, k]
|
||||
hs = scores[control_mask, k]
|
||||
@@ -156,6 +202,7 @@ def main() -> None:
|
||||
age_bins_years=age_bins_years,
|
||||
seed=args.seed,
|
||||
show_progress=show_progress,
|
||||
n_jobs=int(args.max_cpu_cores),
|
||||
)
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
Reference in New Issue
Block a user