feat: Optimize AUC evaluation with parallel processing

This commit is contained in:
2025-10-20 16:16:50 +08:00
parent 8f44018bae
commit 88cccdad2e
2 changed files with 79 additions and 66 deletions

View File

@@ -9,6 +9,7 @@ import numpy as np
import argparse
from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path
from joblib import Parallel, delayed
def auc(x1, x2):
@@ -214,7 +215,7 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
return None
# For controls, we need to exclude cases with disease k
wc = np.where((d[2] != k) * (~(d[2] == k).any(-1))[..., None])
wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])
wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]])) # All cases and controls
@@ -224,36 +225,38 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
else:
pred_idx = precomputed_idx[wall] # It's actually much faster to precompute this
valid_indices = pred_idx != -1
pred_idx = pred_idx[valid_indices]
wall = (wall[0][valid_indices], wall[1][valid_indices])
z = d[1][(wall[0], pred_idx)] # Times of the tokens for prediction
z = z[pred_idx != -1]
zk = d[3][wall] # Target times
zk = zk[pred_idx != -1]
# x = np.exp(p[..., j][(wall[0], pred_idx)]) * 365.25
# x = 1 - np.exp(-x * age_step) # the function is monotinic, so we don't need to do this for the AUC
x = p[..., j][(wall[0], pred_idx)]
x = x[pred_idx != -1]
wk = (wk[0][pred_idx[: len(wk[0])] != -1], wk[1][pred_idx[: len(wk[0])] != -1])
p_idx = wall[0][pred_idx != -1]
p_idx = wall[0]
out = []
for i, aa in enumerate(age_groups):
a = np.logical_and(z / 365.25 >= aa, z / 365.25 < aa + age_step)
# Optionally, add extra filtering on the time difference, for example:
# a *= (zk - z < 365.25)
selected_groups = p_idx[a]
perm = np.random.permutation(len(selected_groups))
_, indices = np.unique(selected_groups[perm], return_index=True)
indices = perm[indices]
selected = np.zeros(np.sum(a), dtype=bool)
selected[indices] = True
a[a] = selected
a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)
if not np.any(a):
continue
control = x[len(wk[0]) :][a[len(wk[0]) :]]
case = x[: len(wk[0])][a[: len(wk[0])]]
selected_groups = p_idx[a]
_, unique_indices = np.unique(selected_groups, return_index=True)
a_filtered = a[a]
a_filtered[:] = False
a_filtered[unique_indices] = True
a[a] = a_filtered
is_case = np.zeros_like(x, dtype=bool)
is_case[:len(wk[0])] = True
control = x[~is_case & a]
case = x[is_case & a]
if len(control) == 0 or len(case) == 0:
continue
@@ -282,6 +285,35 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
return out
def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
all_aucs = []
for sex, sex_idx in [("female", 2), ("male", 3)]:
sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
p_sex = p100k[sex_mask]
d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
for j, k in tqdm(
list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
):
out = get_calibration_auc(
j,
k,
d100k_sex,
p_sex,
age_groups=age_groups,
offset=offset,
precomputed_idx=precomputed_idx_subset,
n_bootstrap=n_bootstrap,
use_delong=True,
)
if out is None:
continue
for out_item in out:
out_item["sex"] = sex
all_aucs.append(out_item)
return all_aucs
# New internal function that performs the AUC evaluation pipeline.
def evaluate_auc_pipeline(
model,
@@ -293,11 +325,12 @@ def evaluate_auc_pipeline(
disease_chunk_size=200,
age_groups=np.arange(40, 80, 5),
offset=0.1,
batch_size=128,
batch_size=256,
device="cpu",
seed=1337,
n_bootstrap=1,
meta_info={},
n_jobs=-1,
):
"""
Runs the AUC evaluation pipeline.
@@ -316,6 +349,7 @@ def evaluate_auc_pipeline(
device (str): Device identifier.
seed (int): Random seed for reproducibility.
n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap)
n_jobs (int): Number of parallel jobs to run.
Returns:
tuple: (df_auc_unpooled, df_auc, df_both) DataFrames.
"""
@@ -337,51 +371,27 @@ def evaluate_auc_pipeline(
# Precompute prediction indices for calibration
pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1
all_aucs = []
tqdm_options = {"desc": "Processing disease chunks", "total": len(diseases_chunks)}
for disease_chunk_idx, diseases_chunk in tqdm(enumerate(diseases_chunks), **tqdm_options):
p100k = []
model.to(device)
with torch.no_grad():
# Process the evaluation data in batches
for dd in tqdm(
zip(*[torch.split(x, batch_size) for x in d100k]),
desc=f"Model inference, chunk {disease_chunk_idx}",
total=d100k[0].shape[0] // batch_size + 1,
):
dd = [x.to(device) for x in dd]
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
# Keep only the columns corresponding to the current disease chunk
p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates
p100k = np.vstack(p100k)
p100k = []
model.to(device)
with torch.no_grad():
for dd in tqdm(
zip(*[torch.split(x, batch_size) for x in d100k]),
desc=f"Model inference",
total=d100k[0].shape[0] // batch_size + 1,
):
dd = [x.to(device) for x in dd]
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
p100k.append(outputs.astype("float16"))
p100k = np.vstack(p100k)
results = Parallel(n_jobs=n_jobs)(
delayed(process_chunk)(
disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap
)
for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
)
# Loop over each disease (token) in the current chunk, sexes separately
for sex, sex_idx in [("female", 2), ("male", 3)]:
sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
p_sex = p100k[sex_mask]
d100k_sex = [d_[sex_mask].cpu().detach().numpy() for d_ in d100k]
precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
for j, k in tqdm(
list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
):
# Get calibration AUC for the current disease token.
out = get_calibration_auc(
j,
k,
d100k_sex,
p_sex,
age_groups=age_groups,
offset=offset,
precomputed_idx=precomputed_idx_subset,
n_bootstrap=n_bootstrap,
use_delong=True,
)
if out is None:
# print(f"No data for disease {k} and sex {sex}")
continue
for out_item in out:
out_item["sex"] = sex
all_aucs.append(out_item)
all_aucs = [item for sublist in results for item in sublist]
df_auc_unpooled = pd.DataFrame(all_aucs)
@@ -427,6 +437,7 @@ def main():
# Optional filtering/chunking parameters:
parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
args = parser.parse_args()
output_path = './'
@@ -475,6 +486,7 @@ def main():
device=device,
seed=seed,
n_bootstrap=args.n_bootstrap,
n_jobs=args.n_jobs,
)

View File

@@ -2,3 +2,4 @@ torch
numpy
tqdm
matplotlib
joblib