Refactor: Remove Jupyter Notebook cell markers

This commit is contained in:
2025-10-18 13:32:26 +08:00
parent dbc3000192
commit 14865ac5b6
2 changed files with 0 additions and 67 deletions

View File

@@ -25,7 +25,6 @@
#
#
# In[2]:
import os
@@ -63,7 +62,6 @@ dark_female = '#7A00BF'
delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
# In[3]:
# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death.
@@ -75,7 +73,6 @@ delphi_labels.iloc[diseases_of_interest][['name', 'ICD-10 Chapter (short)']]
# ## Load model
# In[4]:
out_dir = 'Delphi-2M'
@@ -88,7 +85,6 @@ torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# In[5]:
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
@@ -104,7 +100,6 @@ model = model.to(device)
checkpoint['model_args']
# In[6]:
# Let's try to use the loaded model to extrapolate a partial health trajectory.
@@ -133,7 +128,6 @@ example_health_trajectory = [
example_health_trajectory = [(a, b * 365.25) for a,b in example_health_trajectory]
# In[ ]:
max_new_tokens = 100
@@ -175,7 +169,6 @@ with torch.no_grad():
#
# No-event tokens eliminate long time intervals without tokens, which are typical for younger ages, when people generally have fewer diseases and therefore less medical records. Transformers predict the text token probability distribution only at the time of currently observed tokens, hence, no-event tokens can also be inserted during inference to obtain the predicted disease risk at any given time of interest.
# In[8]:
from utils import get_batch, get_p2i
@@ -194,7 +187,6 @@ dataset_subset_size = 2000 # len(val_p2i) # can be set to smaller number (e.g. 2
# ## Calibration of predicted times
# In[9]:
# Fetch a bit of data and calculate future disease rates from it
@@ -204,13 +196,11 @@ with torch.no_grad():
t = (d[3]-d[1])[:,:].cpu().numpy().squeeze()
# In[10]:
from scipy.special import logsumexp
# Calculate expected waiting times from model predictions using competing exponentials theory
# In Delphi's framework, each possible event has an exponential distribution with rate λᵢ = exp(logits[i])
# The expected time until any event occurs is 1/sum(λᵢ) = 1/exp(logsumexp(logits))
# logsumexp provides numerical stability vs. calculating exp(logits) directly
@@ -250,7 +240,6 @@ plt.gca().tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=0.0, grid_
# ## Incidence
# In[11]:
## Load large chunk of data
@@ -267,7 +256,6 @@ d = get_batch(range(subset_size), val, val_p2i,
device=device, padding='random')
# In[12]:
# 2 is female token, 3 is male token
@@ -277,7 +265,6 @@ is_female = (d[0] == 2).any(axis=1).cpu().numpy()
has_gender = is_male | is_female
# In[13]:
# lets split the large data chanks to smaller batches and calculate the logits for the whole dataset
@@ -295,7 +282,6 @@ d = [d_.cpu() for d_ in d]
# ### Age-sex incidence baseline
# In[14]:
# calculate disease incidence rates for each disease, given age and sex
@@ -334,7 +320,6 @@ females_in_ukb = np.cumsum(np.histogram((females[ukb_condition, 1]/365.25).astyp
#
# Bright dots are often located above the population average rates, which indicates that Delphi correctly captures the elevated disease risk for such participants.
# In[15]:
def plot_age_incidence(ix, d, p, highlight_idx=0):
@@ -444,7 +429,6 @@ def plot_age_incidence(ix, d, p, highlight_idx=0):
axf[i].legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
# In[16]:
plot_age_incidence(diseases_of_interest, d, p, highlight_idx=0)
@@ -463,7 +447,6 @@ plt.show()
# 5. Plot the calibration curve
#
# In[17]:
def auc(x1, x2):
@@ -477,7 +460,6 @@ def auc(x1, x2):
return U1 / n1 / n2
# In[18]:
d100k = get_batch(range(dataset_subset_size), val, val_p2i,
@@ -485,7 +467,6 @@ d100k = get_batch(range(dataset_subset_size), val, val_p2i,
device=device, padding='random')
# In[19]:
p100k = []
@@ -497,7 +478,6 @@ with torch.no_grad():
p100k = np.vstack(p100k)
# In[20]:
import scipy
@@ -626,7 +606,6 @@ def plot_calibration(disease_idx, data, logits, offset = 365.25, age_groups=rang
return out
# In[21]:
out = []
@@ -649,7 +628,6 @@ for j, k in enumerate(diseases_of_interest):
plt.show()
# In[22]:
# the same calibrations curves as above, but a more compact version
@@ -693,7 +671,6 @@ plt.show()
# 4. Calculate the AUC using Delphi disease rates as predictors
# 5. (Optional) Use DeLong's method (recommended) or bootstrap to calculate the variance of the AUC
# In[23]:
from evaluate_auc import get_calibration_auc, evaluate_auc_pipeline
@@ -719,7 +696,6 @@ auc_inputs = {
}
# In[24]:
all_aucs = []
@@ -745,7 +721,6 @@ for disease_idx_batch, disease_idx in tqdm(enumerate(diseases_of_interest), tota
all_aucs.append(out_item)
# In[25]:
# this df contains AUC calculations for all diseases, sexes, and age groups
@@ -779,7 +754,6 @@ auc_df = auc_df.merge(delphi_labels[['name', 'index']], left_on='token', right_o
auc_df
# In[26]:
plt.figure(figsize=(7, 5))
@@ -817,7 +791,6 @@ plt.tight_layout()
#
# Therefore, we well use precomputed results here.
# In[27]:
# df_auc_unpooled_merged, df_auc_merged = evaluate_auc_pipeline(model,
@@ -830,14 +803,12 @@ plt.tight_layout()
# )
# In[28]:
df_auc_all_diseases = pd.read_csv('supplementary/delphi_auc.csv')
df_auc_all_diseases['mean_auc'] = df_auc_all_diseases[['AUC Female, (no gap)', 'AUC Male, (no gap)']].mean(axis=1)
# In[29]:
plt.figure(figsize=(7, 4))
@@ -852,7 +823,6 @@ plt.ylabel('AUC')
plt.show()
# In[30]:
import matplotlib.pyplot as plt
@@ -909,7 +879,6 @@ plt.grid(axis='x', visible=False)
plt.show()
# In[31]:
import matplotlib.pyplot as plt
@@ -984,7 +953,6 @@ plt.show()
#
# Attention maps can be used for interpretability, however for a more robust interpretation, we suggest using SHAP values (`shap_analysis.ipynb`).
# In[32]:
d = get_batch([0], val, val_p2i, select='left', block_size=model.config.block_size, device=device)
@@ -1002,7 +970,6 @@ for i in range(att.shape[0]):
plt.tight_layout()
# In[33]:
d = get_batch(range(dataset_subset_size), val, val_p2i,
@@ -1018,7 +985,6 @@ att.shape
#
# Generally, tokens tend to lose most of their importance pretty quickly. High attention for the most recent token in the trajectory is likely due to this tokens being used by the model to estimate the current age of the patient, which is a very important predictor for the overall disease risk.
# In[34]:
import textwrap
@@ -1045,14 +1011,12 @@ for i in range(len(w[0])):
#
# We see that diseases cluster by their ICD-10 chapter - which is interesting, because the model had no knowledge about the ICD-10 hierarchy during training; all diseases were treated equally.
# In[35]:
import umap
import matplotlib as mpl
# In[36]:
wte = model.transformer.wte.weight.cpu().detach().numpy()
@@ -1064,7 +1028,6 @@ u = u0 - np.median(u0, axis=0)
u = - u
# In[37]:
def remove_ticks(ax):
@@ -1080,7 +1043,6 @@ def remove_ticks(ax):
tick.tick2line.set_visible(False)
# In[38]:
labels_all = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
@@ -1094,7 +1056,6 @@ short_names_present = [i for i in short_names if i in labels_non_technical['ICD-
color_mapping_short = {k: v for k, v in labels_all[['ICD-10 Chapter (short)', 'color']].values}
# In[39]:
import seaborn as sns

View File

@@ -7,7 +7,6 @@
#
# Delphi is a generative autoregressive model that not only predicts the future disease rates, but also sample entire disease trajectories one step at a time.
#
# In this notebook, we will use SHAP (SHapley Additive exPlanations) framework to analyse which interaction between diseases that Delphi learned from data and how these interaction influence its predicitons.
#
# Let's start by looking at what SHAP values mean:
#
@@ -31,9 +30,7 @@
#
# Without speaking about causality, we can assume that there is *some* connection between brain cancer and death risk. SHAP framework allows using such masking to systematically assess the contribution of each token to the prediction. We can perform this analysis for all trajectories in the dataset and evaluate how, on average, a given disease influences the risk of any other disease.
#
# In case of Delphi, masking means replacing a disease token with "no event" token for all input tokens, except for the sex token that is inverted.
# In[2]:
import os
@@ -64,7 +61,6 @@ delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
# ## Load model
# In[3]:
out_dir = 'Delphi-2M'
@@ -81,7 +77,6 @@ device_type = 'cuda' if 'cuda' in device else 'cpu'
dtype = {'float32': torch.float32, 'float64': torch.float64, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# In[4]:
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
@@ -97,7 +92,6 @@ model = model.to(device)
# ## Load data
# In[5]:
from utils import get_batch, get_p2i
@@ -109,7 +103,6 @@ train_p2i = get_p2i(train)
val_p2i = get_p2i(val)
# In[6]:
# define a random example health trajectory
@@ -140,7 +133,6 @@ person = [(a, b * 365.25) for a,b in person]
# ### Individual SHAP values
# In[7]:
# define helper functions
@@ -170,7 +162,6 @@ def get_person(idx):
return person, y, time[0][-1]
# In[8]:
from utils import shap_custom_tokenizer, shap_model_creator
@@ -191,13 +182,11 @@ shap_values = explainer([' '.join(list(map(lambda x: str(token_to_id[x]), person
shap_values.data = np.array([list(map(lambda x: f"{delphi_labels['name'].values[token_to_id[x[0]]]}({x[1]/365:.1f} years) ", person_to_process))])
# In[9]:
out = shap.plots.text(shap_values, display=True) # sometimes this interactive plot can't be rendered well (eg in VS Code, feel free to skip it)
# In[10]:
# SHAP values can be interpreted as how much each input token changes predicted logit corresponding to a particular disease.
@@ -221,7 +210,6 @@ with plt.style.context('default'):
#
# The small synthetic dataset is not enough to properly run following part; if you have access to the full dataset, run `shap-agg-eval.py` to evaluate SHAP values for the entire dataset.
# In[11]:
import pickle
@@ -233,7 +221,6 @@ all_tokens = shap_pkl['tokens']
all_values = shap_pkl['values']
# In[12]:
import pandas as pd
@@ -245,7 +232,6 @@ df_shap = pd.DataFrame(all_values)
df_shap['token'] = all_tokens.astype('int')
# In[13]:
token_count_dict = df_shap['token'].value_counts().sort_index().to_dict()
@@ -266,7 +252,6 @@ df_shap_agg = df_shap[df_shap['token'].apply(lambda x: token_count_dict[x] > N_m
#
# Let's see which diseases increase disease risk the most and also which diseases are most influenced by being a heavy smoker.
# In[14]:
import matplotlib.pyplot as plt
@@ -322,7 +307,6 @@ plot_shap_distribution(
)
# In[15]:
target_token = 9
@@ -362,7 +346,6 @@ plot_shap_distribution(
#
# Now, we will aggregate SHAP values within the pairs, additionally separating them by the time between the "predictor" and "predicted" diseases.
# In[16]:
d = get_batch(range(len(np.unique(shap_pkl['people']))), val, val_p2i,
@@ -370,7 +353,6 @@ d = get_batch(range(len(np.unique(shap_pkl['people']))), val, val_p2i,
device='cpu', padding='regular')
# In[17]:
has_gender = torch.isin(d[0], torch.tensor([2, 3])).any(dim=1).numpy()
@@ -378,7 +360,6 @@ is_male = torch.isin(d[0], torch.tensor([3])).any(dim=1).numpy()
is_female = torch.isin(d[0], torch.tensor([2])).any(dim=1).numpy()
# In[18]:
def get_person(idx):
@@ -394,7 +375,6 @@ def get_person(idx):
return person, y, time[0][-1]
# In[19]:
# the shap result pickle does not contain time, so we need to add it
@@ -419,7 +399,6 @@ for p in tqdm(np.unique(shap_pkl['people'])):
assert len(ages) == len(df_shap)
# In[20]:
all_tokens = shap_pkl['tokens']
@@ -441,7 +420,6 @@ df_shap = df_shap[df_shap['reg_time_years'] > 0]
token_count_dict = df_shap['token'].value_counts().sort_index().to_dict()
# In[22]:
import numpy as np
@@ -458,7 +436,6 @@ def bins_avg(x, y, grid_size=3):
return bin_edges, bin_avgs
# In[23]:
tokens_of_interest = [46, 95, 1168, 1188, 173, 214, 305, 505, 584]
@@ -508,7 +485,6 @@ for num_g, token_group in enumerate(np.array_split(tokens_of_interest, n_groups)
#
# Let's plot two separate heatmaps, one for the cases where the "predictor" disease occured in the past 5 years (with the "predicted disease being the reference) and one for the cases where it occured more than 10 years ago.
# In[24]:
N_min = 5
@@ -527,7 +503,6 @@ df_shap_agg_below_5y = df_shap[df_shap['token'].apply(lambda x: x in columns_mor
df_shap_agg_over_10y = df_shap[df_shap['token'].apply(lambda x: x in columns_more_N) & (df_shap['Time, years'] > 10)].groupby('token').mean()[columns_more_N]
# In[25]:
from matplotlib.colors import LogNorm
@@ -612,18 +587,15 @@ def plot_full_shap_heatmap(cur_df, title):
plt.show()
# In[26]:
plot_full_shap_heatmap(df_shap_agg_below_5y, 'Influence of tokens from below 5 years,\nrisk increase, folds')
# In[27]:
plot_full_shap_heatmap(df_shap_agg_over_10y, 'Influence of tokens from above 10 years,\nrisk increase, folds')
# Interestingly, the resulting heatmap has a block diagonal structure, meaning that within a chapter the interactions between diseases tend to be stronger than between chapters.
#
# The second ("above 10 years") heatmap is also more pale, meaning that most of the disease-disease interactions get weaker over time.