feat: Optimize AUC evaluation with parallel processing
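
Model inference now runs once over the full evaluation set up front, and the
per-chunk AUC computation is fanned out across disease chunks with joblib:

- The loop over sexes and disease tokens moves into a new process_chunk
  function, executed via Parallel(n_jobs=...)(delayed(process_chunk)(...)).
- A new n_jobs parameter and --n_jobs CLI flag (default -1, i.e. all cores)
  control the degree of parallelism.
- Invalid prediction indices (pred_idx == -1) are filtered once up front
  instead of per array; boolean masks use & instead of * / np.logical_and.
- Per-age-group deduplication keeps the first occurrence per group instead of
  a randomly permuted one, and empty age groups are skipped early.
- The default batch_size rises from 128 to 256, and joblib is added to the
  dependencies.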
diff --git a/evaluate_auc.py b/evaluate_auc.py
--- a/evaluate_auc.py
+++ b/evaluate_auc.py
@@ -9,6 +9,7 @@ import numpy as np
 import argparse
 from utils import load_model, get_batch, PatientEventDataset
 from pathlib import Path
+from joblib import Parallel, delayed
 
 
 def auc(x1, x2):
@@ -214,7 +215,7 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
         return None
 
     # For controls, we need to exclude cases with disease k
-    wc = np.where((d[2] != k) * (~(d[2] == k).any(-1))[..., None])
+    wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None])
 
     wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]]))  # All cases and controls
 
@@ -224,36 +225,38 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
     else:
         pred_idx = precomputed_idx[wall]  # It's actually much faster to precompute this
 
+    valid_indices = pred_idx != -1
+    pred_idx = pred_idx[valid_indices]
+    wall = (wall[0][valid_indices], wall[1][valid_indices])
+
     z = d[1][(wall[0], pred_idx)]  # Times of the tokens for prediction
-    z = z[pred_idx != -1]
 
     zk = d[3][wall]  # Target times
-    zk = zk[pred_idx != -1]
 
-    # x = np.exp(p[..., j][(wall[0], pred_idx)]) * 365.25
-    # x = 1 - np.exp(-x * age_step) # the function is monotinic, so we don't need to do this for the AUC
     x = p[..., j][(wall[0], pred_idx)]
-    x = x[pred_idx != -1]
 
-    wk = (wk[0][pred_idx[: len(wk[0])] != -1], wk[1][pred_idx[: len(wk[0])] != -1])
-    p_idx = wall[0][pred_idx != -1]
+    p_idx = wall[0]
 
     out = []
 
     for i, aa in enumerate(age_groups):
-        a = np.logical_and(z / 365.25 >= aa, z / 365.25 < aa + age_step)
-        # Optionally, add extra filtering on the time difference, for example:
-        # a *= (zk - z < 365.25)
+        a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step)
+        if not np.any(a):
+            continue
+
         selected_groups = p_idx[a]
-        perm = np.random.permutation(len(selected_groups))
-        _, indices = np.unique(selected_groups[perm], return_index=True)
-        indices = perm[indices]
-        selected = np.zeros(np.sum(a), dtype=bool)
-        selected[indices] = True
-        a[a] = selected
+        _, unique_indices = np.unique(selected_groups, return_index=True)
+
+        a_filtered = a[a]
+        a_filtered[:] = False
+        a_filtered[unique_indices] = True
+        a[a] = a_filtered
+
+        is_case = np.zeros_like(x, dtype=bool)
+        is_case[: len(wk[0])] = True
 
-        control = x[len(wk[0]) :][a[len(wk[0]) :]]
-        case = x[: len(wk[0])][a[: len(wk[0])]]
+        control = x[~is_case & a]
+        case = x[is_case & a]
 
         if len(control) == 0 or len(case) == 0:
             continue
@@ -282,6 +285,35 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5),
     return out
 
 
+def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap):
+    all_aucs = []
+    for sex, sex_idx in [("female", 2), ("male", 3)]:
+        sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
+        p_sex = p100k[sex_mask]
+        d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k]
+        precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
+        for j, k in tqdm(
+            list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
+        ):
+            out = get_calibration_auc(
+                j,
+                k,
+                d100k_sex,
+                p_sex,
+                age_groups=age_groups,
+                offset=offset,
+                precomputed_idx=precomputed_idx_subset,
+                n_bootstrap=n_bootstrap,
+                use_delong=True,
+            )
+            if out is None:
+                continue
+            for out_item in out:
+                out_item["sex"] = sex
+                all_aucs.append(out_item)
+    return all_aucs
+
+
 # New internal function that performs the AUC evaluation pipeline.
 def evaluate_auc_pipeline(
     model,
@@ -293,11 +325,12 @@ def evaluate_auc_pipeline(
     disease_chunk_size=200,
     age_groups=np.arange(40, 80, 5),
     offset=0.1,
-    batch_size=128,
+    batch_size=256,
     device="cpu",
     seed=1337,
     n_bootstrap=1,
     meta_info={},
+    n_jobs=-1,
 ):
     """
     Runs the AUC evaluation pipeline.
@@ -316,6 +349,7 @@ def evaluate_auc_pipeline(
         device (str): Device identifier.
         seed (int): Random seed for reproducibility.
         n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap)
+        n_jobs (int): Number of parallel jobs to run.
     Returns:
         tuple: (df_auc_unpooled, df_auc, df_both) DataFrames.
     """
@@ -337,51 +371,27 @@ def evaluate_auc_pipeline(
     # Precompute prediction indices for calibration
     pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1
 
-    all_aucs = []
-    tqdm_options = {"desc": "Processing disease chunks", "total": len(diseases_chunks)}
-    for disease_chunk_idx, diseases_chunk in tqdm(enumerate(diseases_chunks), **tqdm_options):
-        p100k = []
-        model.to(device)
-        with torch.no_grad():
-            # Process the evaluation data in batches
-            for dd in tqdm(
-                zip(*[torch.split(x, batch_size) for x in d100k]),
-                desc=f"Model inference, chunk {disease_chunk_idx}",
-                total=d100k[0].shape[0] // batch_size + 1,
-            ):
-                dd = [x.to(device) for x in dd]
-                outputs = model(dd[0], dd[1]).cpu().detach().numpy()
-                # Keep only the columns corresponding to the current disease chunk
-                p100k.append(outputs[:, :, diseases_chunk].astype("float16"))  # enough to store logits, but not rates
-        p100k = np.vstack(p100k)
+    p100k = []
+    model.to(device)
+    with torch.no_grad():
+        for dd in tqdm(
+            zip(*[torch.split(x, batch_size) for x in d100k]),
+            desc=f"Model inference",
+            total=d100k[0].shape[0] // batch_size + 1,
+        ):
+            dd = [x.to(device) for x in dd]
+            outputs = model(dd[0], dd[1]).cpu().detach().numpy()
+            p100k.append(outputs.astype("float16"))
+    p100k = np.vstack(p100k)
 
-        # Loop over each disease (token) in the current chunk, sexes separately
-        for sex, sex_idx in [("female", 2), ("male", 3)]:
-            sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy()
-            p_sex = p100k[sex_mask]
-            d100k_sex = [d_[sex_mask].cpu().detach().numpy() for d_ in d100k]
-            precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy()
-            for j, k in tqdm(
-                list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}"
-            ):
-                # Get calibration AUC for the current disease token.
-                out = get_calibration_auc(
-                    j,
-                    k,
-                    d100k_sex,
-                    p_sex,
-                    age_groups=age_groups,
-                    offset=offset,
-                    precomputed_idx=precomputed_idx_subset,
-                    n_bootstrap=n_bootstrap,
-                    use_delong=True,
-                )
-                if out is None:
-                    # print(f"No data for disease {k} and sex {sex}")
-                    continue
-                for out_item in out:
-                    out_item["sex"] = sex
-                    all_aucs.append(out_item)
+    results = Parallel(n_jobs=n_jobs)(
+        delayed(process_chunk)(
+            disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap
+        )
+        for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks)
+    )
+
+    all_aucs = [item for sublist in results for item in sublist]
 
     df_auc_unpooled = pd.DataFrame(all_aucs)
 
@@ -427,6 +437,7 @@ def main():
     # Optional filtering/chunking parameters:
     parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens")
    parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
+    parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
     args = parser.parse_args()
 
     output_path = './'
@@ -475,6 +486,7 @@ def main():
         device=device,
         seed=seed,
         n_bootstrap=args.n_bootstrap,
+        n_jobs=args.n_jobs,
     )
 
 
@@ -2,3 +2,4 @@ torch
 numpy
 tqdm
 matplotlib
+joblib
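
Note: a minimal, self-contained sketch of the joblib fan-out pattern used
above, with hypothetical stand-in data and a trivial stand-in for the real
process_chunk computation:

    from joblib import Parallel, delayed

    def process_chunk(chunk_idx, chunk):
        # Stand-in for the real per-chunk AUC computation; returns a list of records.
        return [{"chunk": chunk_idx, "value": v} for v in chunk]

    chunks = [[1, 2], [3, 4], [5]]
    # n_jobs=-1 uses all available cores; each delayed(...) call runs in a worker.
    results = Parallel(n_jobs=-1)(
        delayed(process_chunk)(i, c) for i, c in enumerate(chunks)
    )
    all_aucs = [item for sublist in results for item in sublist]  # flatten per-chunk lists

With joblib's default loky backend, arguments are pickled to each worker
process, which is why only the prediction columns a chunk needs
(p100k[:, :, diseases_chunk]) are passed rather than the full array. The flag
is wired through main(), e.g. python evaluate_auc.py --n_jobs 8 with the
script's other arguments unchanged.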