Refactor: Remove Jupyter Notebook cell markers

2025-10-18 13:32:26 +08:00
parent dbc3000192
commit 14865ac5b6
2 changed files with 0 additions and 67 deletions

View File

@@ -25,7 +25,6 @@
#
#
# In[2]:
import os
@@ -63,7 +62,6 @@ dark_female = '#7A00BF'
delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
# In[3]:
# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death.
@@ -75,7 +73,6 @@ delphi_labels.iloc[diseases_of_interest][['name', 'ICD-10 Chapter (short)']]
# ## Load model
# In[4]:
out_dir = 'Delphi-2M'
@@ -88,7 +85,6 @@ torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# In[5]:
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
@@ -104,7 +100,6 @@ model = model.to(device)
checkpoint['model_args']
# In[6]:
# Let's try to use the loaded model to extrapolate a partial health trajectory.
@@ -133,7 +128,6 @@ example_health_trajectory = [
example_health_trajectory = [(a, b * 365.25) for a,b in example_health_trajectory]
# In[ ]:
max_new_tokens = 100
@@ -175,7 +169,6 @@ with torch.no_grad():
#
# No-event tokens eliminate long time intervals without tokens, which are typical for younger ages, when people generally have fewer diseases and therefore fewer medical records. Transformers predict the next-token probability distribution only at the times of currently observed tokens; hence, no-event tokens can also be inserted during inference to obtain the predicted disease risk at any given time of interest.
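# Not part of the original notebook: a rough sketch of the inference trick described above, appending
# a no-event token at an age of interest to read off the disease rates implied at that age. The
# no-event token id (assumed to be 1 here) and the model's forward signature are assumptions.
probe_age_days = 60 * 365.25
tokens = torch.tensor([[t for t, _ in example_health_trajectory] + [1]], device=device)
ages = torch.tensor([[a for _, a in example_health_trajectory] + [probe_age_days]], device=device)
with torch.no_grad():
    logits, _ = model(tokens, ages)          # assumed signature: model(token ids, event ages)
rates_at_probe = torch.exp(logits[0, -1])    # per-token event rates at the probe age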
# In[8]:
from utils import get_batch, get_p2i
@@ -194,7 +187,6 @@ dataset_subset_size = 2000 # len(val_p2i) # can be set to smaller number (e.g. 2
# ## Calibration of predicted times
# In[9]:
# Fetch a bit of data and calculate future disease rates from it # Fetch a bit of data and calculate future disease rates from it
@@ -204,13 +196,11 @@ with torch.no_grad():
t = (d[3]-d[1])[:,:].cpu().numpy().squeeze()
# In[10]:
from scipy.special import logsumexp
# Calculate expected waiting times from model predictions using competing exponentials theory
# In Delphi's framework, each possible event has an exponential distribution with rate λᵢ = exp(logits[i])
# The expected time until any event occurs is 1/sum(λᵢ) = 1/exp(logsumexp(logits))
# logsumexp provides numerical stability vs. calculating exp(logits) directly
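# Not part of the original notebook: a minimal numeric illustration of the competing-exponentials
# formula above, using hypothetical per-token log-rates (in practice the logits come from the model).
import numpy as np
logits_example = np.array([-3.0, -5.2, -7.1])             # hypothetical log-rates, one per token
rates = np.exp(logits_example)                            # λᵢ = exp(logits[i])
expected_wait_days = np.exp(-logsumexp(logits_example))   # 1 / sum(λᵢ), computed stably
assert np.isclose(expected_wait_days, 1 / rates.sum())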
@@ -250,7 +240,6 @@ plt.gca().tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=0.0, grid_
# ## Incidence
# In[11]:
## Load large chunk of data
@@ -267,7 +256,6 @@ d = get_batch(range(subset_size), val, val_p2i,
device=device, padding='random')
# In[12]:
# 2 is female token, 3 is male token # 2 is female token, 3 is male token
@@ -277,7 +265,6 @@ is_female = (d[0] == 2).any(axis=1).cpu().numpy()
has_gender = is_male | is_female
# In[13]:
# let's split the large data chunks into smaller batches and calculate the logits for the whole dataset
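# Not part of the original notebook: an illustrative mini-batching loop for the step described above;
# the forward signature model(tokens, ages) and the meaning of d[0]/d[1] (tokens and ages) follow the
# surrounding code but are assumptions rather than the notebook's exact implementation.
batch_size = 128
p_chunks = []
with torch.no_grad():
    for start in range(0, d[0].shape[0], batch_size):
        tokens_b = d[0][start:start + batch_size].to(device)
        ages_b = d[1][start:start + batch_size].to(device)
        logits_b, _ = model(tokens_b, ages_b)   # assumed signature, as above
        p_chunks.append(logits_b.cpu())
p = torch.cat(p_chunks).numpy()                 # logits for the whole dataset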
@@ -295,7 +282,6 @@ d = [d_.cpu() for d_ in d]
# ### Age-sex incidence baseline
# In[14]:
# calculate disease incidence rates for each disease, given age and sex # calculate disease incidence rates for each disease, given age and sex
@@ -334,7 +320,6 @@ females_in_ukb = np.cumsum(np.histogram((females[ukb_condition, 1]/365.25).astyp
#
# Bright dots are often located above the population average rates, which indicates that Delphi correctly captures the elevated disease risk for such participants.
# In[15]:
def plot_age_incidence(ix, d, p, highlight_idx=0):
@@ -444,7 +429,6 @@ def plot_age_incidence(ix, d, p, highlight_idx=0):
axf[i].legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
# In[16]:
plot_age_incidence(diseases_of_interest, d, p, highlight_idx=0)
@@ -463,7 +447,6 @@ plt.show()
# 5. Plot the calibration curve
#
# In[17]:
def auc(x1, x2):
@@ -477,7 +460,6 @@ def auc(x1, x2):
return U1 / n1 / n2
# In[18]:
d100k = get_batch(range(dataset_subset_size), val, val_p2i,
@@ -485,7 +467,6 @@ d100k = get_batch(range(dataset_subset_size), val, val_p2i,
device=device, padding='random')
# In[19]:
p100k = []
@@ -497,7 +478,6 @@ with torch.no_grad():
p100k = np.vstack(p100k)
# In[20]:
import scipy
@@ -626,7 +606,6 @@ def plot_calibration(disease_idx, data, logits, offset = 365.25, age_groups=rang
return out
# In[21]:
out = []
@@ -649,7 +628,6 @@ for j, k in enumerate(diseases_of_interest):
plt.show()
# In[22]:
# the same calibration curves as above, but a more compact version
@@ -693,7 +671,6 @@ plt.show()
# 4. Calculate the AUC using Delphi disease rates as predictors
# 5. (Optional) Use DeLong's method (recommended) or bootstrap to calculate the variance of the AUC
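# Not part of the original notebook: a bootstrap sketch for step 5, estimating the variance of the AUC
# by resampling cases and controls; it reuses the Mann-Whitney-style auc() helper defined earlier, and
# the argument order (cases first) mirrors that definition as an assumption.
def bootstrap_auc_var(scores_cases, scores_controls, n_boot=200, seed=0):
    rng = np.random.default_rng(seed)
    aucs = []
    for _ in range(n_boot):
        cases_b = rng.choice(scores_cases, size=len(scores_cases), replace=True)
        controls_b = rng.choice(scores_controls, size=len(scores_controls), replace=True)
        aucs.append(auc(cases_b, controls_b))
    return np.var(aucs)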
# In[23]:
from evaluate_auc import get_calibration_auc, evaluate_auc_pipeline
@@ -719,7 +696,6 @@ auc_inputs = {
}
# In[24]:
all_aucs = []
@@ -745,7 +721,6 @@ for disease_idx_batch, disease_idx in tqdm(enumerate(diseases_of_interest), tota
all_aucs.append(out_item)
# In[25]:
# this df contains AUC calculations for all diseases, sexes, and age groups # this df contains AUC calculations for all diseases, sexes, and age groups
@@ -779,7 +754,6 @@ auc_df = auc_df.merge(delphi_labels[['name', 'index']], left_on='token', right_o
auc_df
# In[26]:
plt.figure(figsize=(7, 5))
@@ -817,7 +791,6 @@ plt.tight_layout()
#
# Therefore, we will use precomputed results here.
# In[27]:
# df_auc_unpooled_merged, df_auc_merged = evaluate_auc_pipeline(model,
@@ -830,14 +803,12 @@ plt.tight_layout()
# )
# In[28]:
df_auc_all_diseases = pd.read_csv('supplementary/delphi_auc.csv')
df_auc_all_diseases['mean_auc'] = df_auc_all_diseases[['AUC Female, (no gap)', 'AUC Male, (no gap)']].mean(axis=1)
# In[29]:
plt.figure(figsize=(7, 4))
@@ -852,7 +823,6 @@ plt.ylabel('AUC')
plt.show()
# In[30]:
import matplotlib.pyplot as plt
@@ -909,7 +879,6 @@ plt.grid(axis='x', visible=False)
plt.show()
# In[31]:
import matplotlib.pyplot as plt
@@ -984,7 +953,6 @@ plt.show()
#
# Attention maps can be used for interpretability; however, for a more robust interpretation, we suggest using SHAP values (`shap_analysis.ipynb`).
# In[32]:
d = get_batch([0], val, val_p2i, select='left', block_size=model.config.block_size, device=device)
@@ -1002,7 +970,6 @@ for i in range(att.shape[0]):
plt.tight_layout()
# In[33]:
d = get_batch(range(dataset_subset_size), val, val_p2i,
@@ -1018,7 +985,6 @@ att.shape
#
# Generally, tokens tend to lose most of their importance pretty quickly. High attention for the most recent token in the trajectory is likely due to this token being used by the model to estimate the current age of the patient, which is a very important predictor of the overall disease risk.
# In[34]:
import textwrap
@@ -1045,14 +1011,12 @@ for i in range(len(w[0])):
#
# We see that diseases cluster by their ICD-10 chapter, which is interesting because the model had no knowledge of the ICD-10 hierarchy during training; all diseases were treated equally.
# In[35]:
import umap
import matplotlib as mpl
# In[36]:
wte = model.transformer.wte.weight.cpu().detach().numpy()
@@ -1064,7 +1028,6 @@ u = u0 - np.median(u0, axis=0)
u = - u
# In[37]:
def remove_ticks(ax):
@@ -1080,7 +1043,6 @@ def remove_ticks(ax):
tick.tick2line.set_visible(False)
# In[38]:
labels_all = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
@@ -1094,7 +1056,6 @@ short_names_present = [i for i in short_names if i in labels_non_technical['ICD-
color_mapping_short = {k: v for k, v in labels_all[['ICD-10 Chapter (short)', 'color']].values}
# In[39]:
import seaborn as sns

View File

@@ -7,7 +7,6 @@
#
# Delphi is a generative autoregressive model that not only predicts future disease rates, but also samples entire disease trajectories one step at a time.
#
# In this notebook, we will use the SHAP (SHapley Additive exPlanations) framework to analyse which interactions between diseases Delphi learned from the data and how these interactions influence its predictions.
#
# Let's start by looking at what SHAP values mean:
#
@@ -31,9 +30,7 @@
#
# Without speaking about causality, we can assume that there is *some* connection between brain cancer and death risk. The SHAP framework allows using such masking to systematically assess the contribution of each token to the prediction. We can perform this analysis for all trajectories in the dataset and evaluate how, on average, a given disease influences the risk of any other disease.
#
# In the case of Delphi, masking means replacing a disease token with the "no event" token; this is done for all input tokens, except for the sex token, which is inverted instead.
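# Not part of the original notebook: an illustrative masker implementing the rule above. The sex token
# ids (2 = female, 3 = male) follow the notebook's conventions; the "no event" token id (assumed 1 here)
# is an assumption.
def mask_trajectory(tokens, keep_mask, no_event_token=1):
    masked = []
    for tok, keep in zip(tokens, keep_mask):
        if keep:
            masked.append(tok)              # unmasked tokens are left unchanged
        elif tok in (2, 3):
            masked.append(5 - tok)          # a masked sex token is inverted (2 <-> 3) rather than removed
        else:
            masked.append(no_event_token)   # masked disease tokens become "no event"
    return masked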
# In[2]:
import os
@@ -64,7 +61,6 @@ delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
# ## Load model
# In[3]:
out_dir = 'Delphi-2M'
@@ -81,7 +77,6 @@ device_type = 'cuda' if 'cuda' in device else 'cpu'
dtype = {'float32': torch.float32, 'float64': torch.float64, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# In[4]:
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
@@ -97,7 +92,6 @@ model = model.to(device)
# ## Load data
# In[5]:
from utils import get_batch, get_p2i
@@ -109,7 +103,6 @@ train_p2i = get_p2i(train)
val_p2i = get_p2i(val)
# In[6]:
# define a random example health trajectory # define a random example health trajectory
@@ -140,7 +133,6 @@ person = [(a, b * 365.25) for a,b in person]
# ### Individual SHAP values
# In[7]:
# define helper functions # define helper functions
@@ -170,7 +162,6 @@ def get_person(idx):
return person, y, time[0][-1]
# In[8]:
from utils import shap_custom_tokenizer, shap_model_creator
@@ -191,13 +182,11 @@ shap_values = explainer([' '.join(list(map(lambda x: str(token_to_id[x]), person
shap_values.data = np.array([list(map(lambda x: f"{delphi_labels['name'].values[token_to_id[x[0]]]}({x[1]/365:.1f} years) ", person_to_process))])
# In[9]:
out = shap.plots.text(shap_values, display=True) # sometimes this interactive plot can't be rendered well (eg in VS Code, feel free to skip it)
# In[10]:
# SHAP values can be interpreted as how much each input token changes the predicted logit corresponding to a particular disease.
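# Not part of the original notebook: a tiny illustration of reading a logit-scale SHAP value as a
# multiplicative change in the predicted rate (Delphi models the per-disease rate as exp(logit), so an
# additive shift s on the logit multiplies the rate by exp(s)); the numbers below are hypothetical.
import numpy as np
shap_logit = 0.7                               # hypothetical SHAP value for one token/disease pair
risk_fold_change = float(np.exp(shap_logit))   # roughly a 2x increase in the predicted event rate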
@@ -221,7 +210,6 @@ with plt.style.context('default'):
#
# The small synthetic dataset is not enough to properly run the following part; if you have access to the full dataset, run `shap-agg-eval.py` to evaluate SHAP values for the entire dataset.
# In[11]:
import pickle
@@ -233,7 +221,6 @@ all_tokens = shap_pkl['tokens']
all_values = shap_pkl['values']
# In[12]:
import pandas as pd
@@ -245,7 +232,6 @@ df_shap = pd.DataFrame(all_values)
df_shap['token'] = all_tokens.astype('int')
# In[13]:
token_count_dict = df_shap['token'].value_counts().sort_index().to_dict()
@@ -266,7 +252,6 @@ df_shap_agg = df_shap[df_shap['token'].apply(lambda x: token_count_dict[x] > N_m
#
# Let's see which diseases increase disease risk the most and also which diseases are most influenced by being a heavy smoker.
# In[14]:
import matplotlib.pyplot as plt
@@ -322,7 +307,6 @@ plot_shap_distribution(
)
# In[15]:
target_token = 9
@@ -362,7 +346,6 @@ plot_shap_distribution(
#
# Now, we will aggregate SHAP values within the pairs, additionally separating them by the time between the "predictor" and "predicted" diseases.
# In[16]:
d = get_batch(range(len(np.unique(shap_pkl['people']))), val, val_p2i,
@@ -370,7 +353,6 @@ d = get_batch(range(len(np.unique(shap_pkl['people']))), val, val_p2i,
device='cpu', padding='regular')
# In[17]:
has_gender = torch.isin(d[0], torch.tensor([2, 3])).any(dim=1).numpy()
@@ -378,7 +360,6 @@ is_male = torch.isin(d[0], torch.tensor([3])).any(dim=1).numpy()
is_female = torch.isin(d[0], torch.tensor([2])).any(dim=1).numpy()
# In[18]:
def get_person(idx):
@@ -394,7 +375,6 @@ def get_person(idx):
return person, y, time[0][-1]
# In[19]:
# the shap result pickle does not contain time, so we need to add it # the shap result pickle does not contain time, so we need to add it
@@ -419,7 +399,6 @@ for p in tqdm(np.unique(shap_pkl['people'])):
assert len(ages) == len(df_shap)
# In[20]:
all_tokens = shap_pkl['tokens']
@@ -441,7 +420,6 @@ df_shap = df_shap[df_shap['reg_time_years'] > 0]
token_count_dict = df_shap['token'].value_counts().sort_index().to_dict()
# In[22]:
import numpy as np
@@ -458,7 +436,6 @@ def bins_avg(x, y, grid_size=3):
return bin_edges, bin_avgs
# In[23]:
tokens_of_interest = [46, 95, 1168, 1188, 173, 214, 305, 505, 584]
@@ -508,7 +485,6 @@ for num_g, token_group in enumerate(np.array_split(tokens_of_interest, n_groups)
#
# Let's plot two separate heatmaps, one for the cases where the "predictor" disease occurred in the past 5 years (with the "predicted" disease being the reference) and one for the cases where it occurred more than 10 years ago.
# In[24]:
N_min = 5
@@ -527,7 +503,6 @@ df_shap_agg_below_5y = df_shap[df_shap['token'].apply(lambda x: x in columns_mor
df_shap_agg_over_10y = df_shap[df_shap['token'].apply(lambda x: x in columns_more_N) & (df_shap['Time, years'] > 10)].groupby('token').mean()[columns_more_N]
# In[25]:
from matplotlib.colors import LogNorm
@@ -612,18 +587,15 @@ def plot_full_shap_heatmap(cur_df, title):
plt.show()
# In[26]:
plot_full_shap_heatmap(df_shap_agg_below_5y, 'Influence of tokens from below 5 years,\nrisk increase, folds')
# In[27]:
plot_full_shap_heatmap(df_shap_agg_over_10y, 'Influence of tokens from above 10 years,\nrisk increase, folds')
# Interestingly, the resulting heatmap has a block diagonal structure, meaning that within a chapter the interactions between diseases tend to be stronger than between chapters.
#
# The second ("above 10 years") heatmap is also paler, meaning that most of the disease-disease interactions get weaker over time.