feat: Optimize AUC evaluation with parallel processing

evaluate_auc.py (144 changed lines)

@@ -9,6 +9,7 @@ import numpy as np
 import argparse
 from utils import load_model, get_batch, PatientEventDataset
 from pathlib import Path
+from joblib import Parallel, delayed


 def auc(x1, x2):
@@ -214,7 +215,7 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
         return None

     # For controls, we need to exclude cases with disease k
-    wc = np.where((d[2] != k) * (~(d[2] == k).any(-1))[..., None])
+    wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])

     wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]]))  # All cases and controls
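Note on the one-line fix above: both operands are boolean arrays, so the old `*` (elementwise multiply) and the new `&` (elementwise AND) produce identical masks here; `&` states the logical intent and avoids silent integer upcasting if either operand ever stops being boolean. A self-contained sketch with toy arrays (not the pipeline's data):

    import numpy as np

    d2 = np.array([[1, 2], [3, 3]])  # toy stand-in for d[2]: disease tokens per patient
    k = 3

    per_token = d2 != k                        # tokens that are not disease k
    never_k = (~(d2 == k).any(-1))[..., None]  # patients with no disease-k token at all
    mask = per_token & never_k                 # elementwise AND, stays boolean
    print(np.where(mask))                      # -> (array([0, 0]), array([0, 1]))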
@@ -224,36 +225,38 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
     else:
         pred_idx = precomputed_idx[wall]  # It's actually much faster to precompute this

+    valid_indices = pred_idx != -1
+    pred_idx = pred_idx[valid_indices]
+    wall = (wall[0][valid_indices], wall[1][valid_indices])
+
     z = d[1][(wall[0], pred_idx)]  # Times of the tokens for prediction
-    z = z[pred_idx != -1]

     zk = d[3][wall]  # Target times
-    zk = zk[pred_idx != -1]

     # x = np.exp(p[..., j][(wall[0], pred_idx)]) * 365.25
     # x = 1 - np.exp(-x * age_step)  # the function is monotonic, so we don't need to do this for the AUC
     x = p[..., j][(wall[0], pred_idx)]
-    x = x[pred_idx != -1]

-    wk = (wk[0][pred_idx[: len(wk[0])] != -1], wk[1][pred_idx[: len(wk[0])] != -1])
-    p_idx = wall[0][pred_idx != -1]
+    p_idx = wall[0]

     out = []

     for i, aa in enumerate(age_groups):
-        a = np.logical_and(z / 365.25 >= aa, z / 365.25 < aa + age_step)
-        # Optionally, add extra filtering on the time difference, for example:
-        # a *= (zk - z < 365.25)
-        selected_groups = p_idx[a]
-        perm = np.random.permutation(len(selected_groups))
-        _, indices = np.unique(selected_groups[perm], return_index=True)
-        indices = perm[indices]
-        selected = np.zeros(np.sum(a), dtype=bool)
-        selected[indices] = True
-        a[a] = selected
+        a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)
+        if not np.any(a):
+            continue

-        control = x[len(wk[0]) :][a[len(wk[0]) :]]
-        case = x[: len(wk[0])][a[: len(wk[0])]]
+        selected_groups = p_idx[a]
+        _, unique_indices = np.unique(selected_groups, return_index=True)
+        a_filtered = a[a]
+        a_filtered[:] = False
+        a_filtered[unique_indices] = True
+        a[a] = a_filtered
+
+        is_case = np.zeros_like(x, dtype=bool)
+        is_case[: len(wk[0])] = True
+        control = x[~is_case & a]
+        case = x[is_case & a]

         if len(control) == 0 or len(case) == 0:
             continue
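The rewritten block above keeps at most one prediction per patient within each age group, deterministically: np.unique(..., return_index=True) returns the index of the first occurrence of each patient id, and those indices re-thin the boolean mask a in place (the removed variant shuffled with a random permutation first, so it kept a random occurrence rather than the first). The same trick on toy data:

    import numpy as np

    a = np.array([True, True, False, True, True])  # age-group mask over predictions
    p_idx = np.array([7, 7, 8, 8, 9])              # patient id behind each prediction

    selected_groups = p_idx[a]                     # ids of the masked-in rows -> [7, 7, 8, 9]
    _, unique_indices = np.unique(selected_groups, return_index=True)

    a_filtered = a[a]                              # all-True copy, one slot per masked-in row
    a_filtered[:] = False
    a_filtered[unique_indices] = True              # keep only the first row per patient
    a[a] = a_filtered
    print(a)                                       # -> [ True False False  True  True]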
@@ -282,6 +285,35 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
     return out


+def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
+    all_aucs = []
+    for sex, sex_idx in [("female", 2), ("male", 3)]:
+        sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
+        p_sex = p100k[sex_mask]
+        d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
+        precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
+        for j, k in tqdm(
+            list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
+        ):
+            out = get_calibration_auc(
+                j,
+                k,
+                d100k_sex,
+                p_sex,
+                age_groups=age_groups,
+                offset=offset,
+                precomputed_idx=precomputed_idx_subset,
+                n_bootstrap=n_bootstrap,
+                use_delong=True,
+            )
+            if out is None:
+                continue
+            for out_item in out:
+                out_item["sex"] = sex
+                all_aucs.append(out_item)
+    return all_aucs
+
+
 # New internal function that performs the AUC evaluation pipeline.
 def evaluate_auc_pipeline(
     model,
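In process_chunk, the sex_mask line marks patients whose token sequence contains the sex token at all: (d100k[0] == sex_idx).sum(1) > 0 counts matches per row and thresholds at zero. A minimal torch sketch of the same membership test (toy tensor; token ids 2 and 3 taken from the loop above):

    import torch

    tokens = torch.tensor([[2, 5, 0],   # patient 0: contains female token (2)
                           [3, 7, 7],   # patient 1: contains male token (3)
                           [5, 9, 0]])  # patient 2: neither

    female_mask = (tokens == 2).sum(1) > 0  # count-then-threshold, as in the diff
    male_mask = (tokens == 3).any(1)        # .any(1) is an equivalent spelling
    print(female_mask.tolist(), male_mask.tolist())
    # -> [True, False, False] [False, True, False]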
@@ -293,11 +325,12 @@ def evaluate_auc_pipeline(
     disease_chunk_size=200,
     age_groups=np.arange(40, 80, 5),
     offset=0.1,
-    batch_size=128,
+    batch_size=256,
     device="cpu",
     seed=1337,
     n_bootstrap=1,
     meta_info={},
+    n_jobs=-1,
 ):
     """
     Runs the AUC evaluation pipeline.
@@ -316,6 +349,7 @@ def evaluate_auc_pipeline(
         device (str): Device identifier.
         seed (int): Random seed for reproducibility.
         n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap)
+        n_jobs (int): Number of parallel jobs to run.
     Returns:
         tuple: (df_auc_unpooled, df_auc, df_both) DataFrames.
     """
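The next hunk leaves the precomputation context line unchanged, and it is worth unpacking. Broadcasting (N, T, 1) event times against (N, 1, K) target times gives an (N, T, K) comparison; summing over T counts the tokens that occur strictly before each target minus the offset, and subtracting 1 converts that count into the index of the last usable token, with -1 meaning no token qualifies. A shape-annotated sketch (toy arrays; assumes d100k[1] holds per-token event times (N, T) and d100k[3] per-disease target times (N, K), as their use as z and zk above suggests):

    import numpy as np

    times = np.array([[1.0, 2.0, 5.0]])  # (N=1, T=3) event times
    targets = np.array([[4.0, 0.5]])     # (N=1, K=2) target times
    offset = 0.1

    # (N, T, 1) < (N, 1, K) broadcasts to (N, T, K); the sum over axis 1
    # counts tokens strictly before each target minus the offset.
    pred_idx = (times[:, :, np.newaxis] < targets[:, np.newaxis, :] - offset).sum(1) - 1
    print(pred_idx)  # -> [[ 1 -1]]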
@@ -337,51 +371,27 @@ def evaluate_auc_pipeline(
     # Precompute prediction indices for calibration
     pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1

-    all_aucs = []
-    tqdm_options = {"desc": "Processing disease chunks", "total": len(diseases_chunks)}
-    for disease_chunk_idx, diseases_chunk in tqdm(enumerate(diseases_chunks), **tqdm_options):
-        p100k = []
-        model.to(device)
-        with torch.no_grad():
-            # Process the evaluation data in batches
-            for dd in tqdm(
-                zip(*[torch.split(x, batch_size) for x in d100k]),
-                desc=f"Model inference, chunk {disease_chunk_idx}",
-                total=d100k[0].shape[0] // batch_size + 1,
-            ):
-                dd = [x.to(device) for x in dd]
-                outputs = model(dd[0], dd[1]).cpu().detach().numpy()
-                # Keep only the columns corresponding to the current disease chunk
-                p100k.append(outputs[:, :, diseases_chunk].astype("float16"))  # enough to store logits, but not rates
-        p100k = np.vstack(p100k)
+    p100k = []
+    model.to(device)
+    with torch.no_grad():
+        for dd in tqdm(
+            zip(*[torch.split(x, batch_size) for x in d100k]),
+            desc=f"Model inference",
+            total=d100k[0].shape[0] // batch_size + 1,
+        ):
+            dd = [x.to(device) for x in dd]
+            outputs = model(dd[0], dd[1]).cpu().detach().numpy()
+            p100k.append(outputs.astype("float16"))
+    p100k = np.vstack(p100k)

-        # Loop over each disease (token) in the current chunk, sexes separately
-        for sex, sex_idx in [("female", 2), ("male", 3)]:
-            sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
-            p_sex = p100k[sex_mask]
-            d100k_sex = [d_[sex_mask].cpu().detach().numpy() for d_ in d100k]
-            precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
-            for j, k in tqdm(
-                list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
-            ):
-                # Get calibration AUC for the current disease token.
-                out = get_calibration_auc(
-                    j,
-                    k,
-                    d100k_sex,
-                    p_sex,
-                    age_groups=age_groups,
-                    offset=offset,
-                    precomputed_idx=precomputed_idx_subset,
-                    n_bootstrap=n_bootstrap,
-                    use_delong=True,
-                )
-                if out is None:
-                    # print(f"No data for disease {k} and sex {sex}")
-                    continue
-                for out_item in out:
-                    out_item["sex"] = sex
-                    all_aucs.append(out_item)
+    results = Parallel(n_jobs=n_jobs)(
+        delayed(process_chunk)(
+            disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap
+        )
+        for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
+    )
+
+    all_aucs = [item for sublist in results for item in sublist]

     df_auc_unpooled = pd.DataFrame(all_aucs)
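The Parallel call above is the heart of the commit: model inference now runs once over all patients, and only the per-chunk AUC work fans out to worker processes via joblib. Slicing p100k[:, :, diseases_chunk] per task limits each worker's payload to its own float16 columns, and with the default backend joblib memory-maps large NumPy arguments (above roughly 1 MB) rather than pickling a full copy into every process. A stripped-down sketch of the same pattern, with a hypothetical work function standing in for process_chunk:

    import numpy as np
    from joblib import Parallel, delayed

    def work(chunk_idx, columns, scores):
        # toy stand-in for process_chunk: one stats dict per column
        return [{"chunk": chunk_idx, "col": int(c), "mean": float(scores[:, c].mean())}
                for c in columns]

    scores = np.random.rand(1000, 10)
    chunks = [np.arange(0, 5), np.arange(5, 10)]

    results = Parallel(n_jobs=-1)(  # n_jobs=-1: use all available cores
        delayed(work)(i, chunk, scores) for i, chunk in enumerate(chunks)
    )
    flat = [item for sublist in results for item in sublist]  # same flattening as above
    print(len(flat))  # -> 10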
@@ -427,6 +437,7 @@ def main():
     # Optional filtering/chunking parameters:
     parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
     parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
+    parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
     args = parser.parse_args()

     output_path = './'
@@ -475,6 +486,7 @@ def main():
         device=device,
         seed=seed,
         n_bootstrap=args.n_bootstrap,
+        n_jobs=args.n_jobs,
     )

@@ -2,3 +2,4 @@ torch
 numpy
 tqdm
 matplotlib
+joblib