Add --max_cpu_cores argument for parallel processing in evaluation scripts

2026-01-17 23:53:24 +08:00
parent 248fb09c34
commit cfe7f88162
4 changed files with 234 additions and 27 deletions


@@ -50,6 +50,12 @@ def parse_args() -> argparse.Namespace:
     )
     p.add_argument("--batch_size", type=int, default=256)
     p.add_argument("--num_workers", type=int, default=0)
+    p.add_argument(
+        "--max_cpu_cores",
+        type=int,
+        default=-1,
+        help="Maximum number of CPU cores to use for parallel data construction.",
+    )
     p.add_argument("--seed", type=int, default=0)
     p.add_argument(
         "--min_pos",
@@ -96,6 +102,45 @@ def _compute_next_event_auc_clean_control(
         dtype=np.int64,
     )
+    # Pre-compute lifetime disease membership matrix for vectorized clean-control filtering.
+    # lifetime_matrix[i, c] == True iff cause c is present in records[i].lifetime_causes.
+    # Use a sparse matrix when SciPy is available to keep memory bounded for large K.
+    row_parts: List[np.ndarray] = []
+    col_parts: List[np.ndarray] = []
+    for i, r in enumerate(records):
+        causes = getattr(r, "lifetime_causes", None)
+        if causes is None:
+            continue
+        causes = np.asarray(causes, dtype=np.int64)
+        if causes.size == 0:
+            continue
+        # Keep only valid cause ids.
+        m_valid = (causes >= 0) & (causes < K)
+        if not np.any(m_valid):
+            continue
+        causes = causes[m_valid]
+        row_parts.append(np.full((causes.size,), i, dtype=np.int32))
+        col_parts.append(causes.astype(np.int32, copy=False))
+    try:
+        import scipy.sparse as sp  # type: ignore
+        if row_parts:
+            rows = np.concatenate(row_parts, axis=0)
+            cols = np.concatenate(col_parts, axis=0)
+            data = np.ones((rows.size,), dtype=bool)
+            lifetime_matrix = sp.csc_matrix(
+                (data, (rows, cols)), shape=(n_records, K))
+        else:
+            lifetime_matrix = sp.csc_matrix((n_records, K), dtype=bool)
+        lifetime_is_sparse = True
+    except Exception:  # pragma: no cover
+        lifetime_matrix = np.zeros((n_records, K), dtype=bool)
+        for rows, cols in zip(row_parts, col_parts):
+            lifetime_matrix[rows.astype(np.int64, copy=False), cols.astype(
+                np.int64, copy=False)] = True
+        lifetime_is_sparse = False
     auc = np.full((K,), np.nan, dtype=np.float64)
     var = np.full((K,), np.nan, dtype=np.float64)
     n_case = np.zeros((K,), dtype=np.int64)
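For context, a toy illustration (not from this repository; the small rows/cols arrays are made up) of how a boolean CSC membership matrix supports the vectorized "ever had cause k" lookup that the new code performs per cause:

import numpy as np
import scipy.sparse as sp

rows = np.array([0, 0, 2], dtype=np.int32)   # record indices
cols = np.array([1, 3, 1], dtype=np.int32)   # cause ids
data = np.ones(rows.size, dtype=bool)
lifetime = sp.csc_matrix((data, (rows, cols)), shape=(3, 5))

# Column k as a dense boolean mask over records.
had_k = np.asarray(lifetime.getcol(1).toarray().ravel(), dtype=bool)
# had_k == [True, False, True]; clean controls for cause 1 are ~had_k.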
@@ -109,12 +154,13 @@ def _compute_next_event_auc_clean_control(
         # Clean controls: not next-event k AND never had k in their lifetime history.
         control_mask = y_next != k
         if np.any(control_mask):
-            clean = np.fromiter(
-                ((k not in rec.lifetime_causes) for rec in records),
-                dtype=bool,
-                count=n_records,
-            )
-            control_mask = control_mask & clean
+            if lifetime_is_sparse:
+                had_k = np.asarray(lifetime_matrix.getcol(
+                    k).toarray().ravel(), dtype=bool)
+            else:
+                had_k = lifetime_matrix[:, k]
+            is_clean = ~had_k
+            control_mask = control_mask & is_clean
         cs = scores[case_mask, k]
         hs = scores[control_mask, k]
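The AUC computation downstream of cs and hs is outside this hunk. As a reference, the standard rank-based (Mann-Whitney) estimate of AUC from case and control scores looks like the sketch below; this is an assumption about the approach, not the commit's code:

import numpy as np
from scipy.stats import rankdata

def mann_whitney_auc(cs: np.ndarray, hs: np.ndarray) -> float:
    # AUC as the probability that a random case outscores a random control.
    n_case, n_ctrl = cs.size, hs.size
    if n_case == 0 or n_ctrl == 0:
        return float("nan")
    ranks = rankdata(np.concatenate([cs, hs]))  # average ranks handle ties
    u = ranks[:n_case].sum() - n_case * (n_case + 1) / 2.0
    return float(u / (n_case * n_ctrl))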
@@ -156,6 +202,7 @@ def main() -> None:
         age_bins_years=age_bins_years,
         seed=args.seed,
         show_progress=show_progress,
+        n_jobs=int(args.max_cpu_cores),
     )
     device = torch.device(args.device)
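The data-construction routine that receives n_jobs is not part of this diff. One common pattern for honoring such a parameter, sketched under the assumption that joblib is used (build_one_example and build_dataset are hypothetical names):

from joblib import Parallel, delayed

def build_one_example(rec):
    # Placeholder per-record transform; the real feature construction
    # lives in the evaluation code and is not shown here.
    return rec

def build_dataset(records, n_jobs: int = -1):
    # n_jobs=-1 means "all cores" under joblib's convention.
    return Parallel(n_jobs=n_jobs)(
        delayed(build_one_example)(rec) for rec in records
    )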