diff --git a/evaluate_auc.py b/evaluate_auc.py index 7329ec5..7dfc6c1 100644 --- a/evaluate_auc.py +++ b/evaluate_auc.py @@ -9,6 +9,7 @@ import numpy as np import argparse from utils import load_model, get_batch, PatientEventDataset from pathlib import Path +from joblib import Parallel, delayed def auc(x1, x2): @@ -214,7 +215,7 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), return None # For controls, we need to exclude cases with disease k - wc = np.where((d[2] != k) * (~(d[2] == k).any(-1))[..., None]) + wc = np.where((d[2] != k) & (~(d[2] == k).any(-1))[..., None]) wall = (np.concatenate([wk[0], wc[0]]), np.concatenate([wk[1], wc[1]])) # All cases and controls @@ -224,36 +225,38 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), else: pred_idx = precomputed_idx[wall] # It's actually much faster to precompute this + valid_indices = pred_idx != -1 + pred_idx = pred_idx[valid_indices] + wall = (wall[0][valid_indices], wall[1][valid_indices]) + z = d[1][(wall[0], pred_idx)] # Times of the tokens for prediction - z = z[pred_idx != -1] - zk = d[3][wall] # Target times - zk = zk[pred_idx != -1] - # x = np.exp(p[..., j][(wall[0], pred_idx)]) * 365.25 - # x = 1 - np.exp(-x * age_step) # the function is monotinic, so we don't need to do this for the AUC x = p[..., j][(wall[0], pred_idx)] - x = x[pred_idx != -1] - wk = (wk[0][pred_idx[: len(wk[0])] != -1], wk[1][pred_idx[: len(wk[0])] != -1]) - p_idx = wall[0][pred_idx != -1] + p_idx = wall[0] out = [] for i, aa in enumerate(age_groups): - a = np.logical_and(z / 365.25 >= aa, z / 365.25 < aa + age_step) - # Optionally, add extra filtering on the time difference, for example: - # a *= (zk - z < 365.25) - selected_groups = p_idx[a] - perm = np.random.permutation(len(selected_groups)) - _, indices = np.unique(selected_groups[perm], return_index=True) - indices = perm[indices] - selected = np.zeros(np.sum(a), dtype=bool) - selected[indices] = True - a[a] = selected + a = (z / 365.25 >= aa) & (z / 365.25 < aa + age_step) + + if not np.any(a): + continue - control = x[len(wk[0]) :][a[len(wk[0]) :]] - case = x[: len(wk[0])][a[: len(wk[0])]] + selected_groups = p_idx[a] + _, unique_indices = np.unique(selected_groups, return_index=True) + + a_filtered = a[a] + a_filtered[:] = False + a_filtered[unique_indices] = True + a[a] = a_filtered + + is_case = np.zeros_like(x, dtype=bool) + is_case[:len(wk[0])] = True + + control = x[~is_case & a] + case = x[is_case & a] if len(control) == 0 or len(case) == 0: continue @@ -282,6 +285,35 @@ def get_calibration_auc(j, k, d, p, offset=365.25, age_groups=range(45, 80, 5), return out +def process_chunk(disease_chunk_idx, diseases_chunk, d100k, p100k, pred_idx_precompute, age_groups, offset, n_bootstrap): + all_aucs = [] + for sex, sex_idx in [("female", 2), ("male", 3)]: + sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy() + p_sex = p100k[sex_mask] + d100k_sex = [d_.cpu().detach().numpy()[sex_mask] for d_ in d100k] + precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy() + for j, k in tqdm( + list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}" + ): + out = get_calibration_auc( + j, + k, + d100k_sex, + p_sex, + age_groups=age_groups, + offset=offset, + precomputed_idx=precomputed_idx_subset, + n_bootstrap=n_bootstrap, + use_delong=True, + ) + if out is None: + continue + for out_item in out: + out_item["sex"] = sex + all_aucs.append(out_item) + return all_aucs + + # New internal function that performs the AUC evaluation pipeline. def evaluate_auc_pipeline( model, @@ -293,11 +325,12 @@ def evaluate_auc_pipeline( disease_chunk_size=200, age_groups=np.arange(40, 80, 5), offset=0.1, - batch_size=128, + batch_size=256, device="cpu", seed=1337, n_bootstrap=1, meta_info={}, + n_jobs=-1, ): """ Runs the AUC evaluation pipeline. @@ -316,6 +349,7 @@ def evaluate_auc_pipeline( device (str): Device identifier. seed (int): Random seed for reproducibility. n_bootstrap (int): Number of bootstrap samples. (1 for no bootstrap) + n_jobs (int): Number of parallel jobs to run. Returns: tuple: (df_auc_unpooled, df_auc, df_both) DataFrames. """ @@ -337,51 +371,27 @@ def evaluate_auc_pipeline( # Precompute prediction indices for calibration pred_idx_precompute = (d100k[1][:, :, np.newaxis] < d100k[3][:, np.newaxis, :] - offset).sum(1) - 1 - all_aucs = [] - tqdm_options = {"desc": "Processing disease chunks", "total": len(diseases_chunks)} - for disease_chunk_idx, diseases_chunk in tqdm(enumerate(diseases_chunks), **tqdm_options): - p100k = [] - model.to(device) - with torch.no_grad(): - # Process the evaluation data in batches - for dd in tqdm( - zip(*[torch.split(x, batch_size) for x in d100k]), - desc=f"Model inference, chunk {disease_chunk_idx}", - total=d100k[0].shape[0] // batch_size + 1, - ): - dd = [x.to(device) for x in dd] - outputs = model(dd[0], dd[1]).cpu().detach().numpy() - # Keep only the columns corresponding to the current disease chunk - p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates - p100k = np.vstack(p100k) + p100k = [] + model.to(device) + with torch.no_grad(): + for dd in tqdm( + zip(*[torch.split(x, batch_size) for x in d100k]), + desc=f"Model inference", + total=d100k[0].shape[0] // batch_size + 1, + ): + dd = [x.to(device) for x in dd] + outputs = model(dd[0], dd[1]).cpu().detach().numpy() + p100k.append(outputs.astype("float16")) + p100k = np.vstack(p100k) + + results = Parallel(n_jobs=n_jobs)( + delayed(process_chunk)( + disease_chunk_idx, diseases_chunk, d100k, p100k[:, :, diseases_chunk], pred_idx_precompute, age_groups, offset, n_bootstrap + ) + for disease_chunk_idx, diseases_chunk in enumerate(diseases_chunks) + ) - # Loop over each disease (token) in the current chunk, sexes separately - for sex, sex_idx in [("female", 2), ("male", 3)]: - sex_mask = ((d100k[0] == sex_idx).sum(1) > 0).cpu().detach().numpy() - p_sex = p100k[sex_mask] - d100k_sex = [d_[sex_mask].cpu().detach().numpy() for d_ in d100k] - precomputed_idx_subset = pred_idx_precompute[sex_mask].cpu().detach().numpy() - for j, k in tqdm( - list(enumerate(diseases_chunk)), desc=f"Processing diseases in chunk {disease_chunk_idx}, {sex}" - ): - # Get calibration AUC for the current disease token. - out = get_calibration_auc( - j, - k, - d100k_sex, - p_sex, - age_groups=age_groups, - offset=offset, - precomputed_idx=precomputed_idx_subset, - n_bootstrap=n_bootstrap, - use_delong=True, - ) - if out is None: - # print(f"No data for disease {k} and sex {sex}") - continue - for out_item in out: - out_item["sex"] = sex - all_aucs.append(out_item) + all_aucs = [item for sublist in results for item in sublist] df_auc_unpooled = pd.DataFrame(all_aucs) @@ -427,6 +437,7 @@ def main(): # Optional filtering/chunking parameters: parser.add_argument("--filter_min_total", type=int, default=100, help="Minimum total count to filter tokens") parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases") + parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run") args = parser.parse_args() output_path = './' @@ -475,6 +486,7 @@ def main(): device=device, seed=seed, n_bootstrap=args.n_bootstrap, + n_jobs=args.n_jobs, ) diff --git a/requirements.txt b/requirements.txt index 534d633..e01d1d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ torch numpy tqdm matplotlib +joblib