Refactor: Remove Jupyter Notebook cell markers

This commit is contained in:
2025-10-18 13:32:26 +08:00
parent dbc3000192
commit 14865ac5b6
2 changed files with 0 additions and 67 deletions

View File

@@ -7,7 +7,6 @@
#
# Delphi is a generative autoregressive model that not only predicts the future disease rates, but also sample entire disease trajectories one step at a time.
#
# In this notebook, we will use SHAP (SHapley Additive exPlanations) framework to analyse which interaction between diseases that Delphi learned from data and how these interaction influence its predicitons.
#
# Let's start by looking at what SHAP values mean:
#
@@ -31,9 +30,7 @@
#
# Without speaking about causality, we can assume that there is *some* connection between brain cancer and death risk. SHAP framework allows using such masking to systematically assess the contribution of each token to the prediction. We can perform this analysis for all trajectories in the dataset and evaluate how, on average, a given disease influences the risk of any other disease.
#
# In case of Delphi, masking means replacing a disease token with "no event" token for all input tokens, except for the sex token that is inverted.
# In[2]:
import os
@@ -64,7 +61,6 @@ delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
# ## Load model
# In[3]:
out_dir = 'Delphi-2M'
@@ -81,7 +77,6 @@ device_type = 'cuda' if 'cuda' in device else 'cpu'
dtype = {'float32': torch.float32, 'float64': torch.float64, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# In[4]:
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
@@ -97,7 +92,6 @@ model = model.to(device)
# ## Load data
# In[5]:
from utils import get_batch, get_p2i
@@ -109,7 +103,6 @@ train_p2i = get_p2i(train)
val_p2i = get_p2i(val)
# In[6]:
# define a random example health trajectory
@@ -140,7 +133,6 @@ person = [(a, b * 365.25) for a,b in person]
# ### Individual SHAP values
# In[7]:
# define helper functions
@@ -170,7 +162,6 @@ def get_person(idx):
return person, y, time[0][-1]
# In[8]:
from utils import shap_custom_tokenizer, shap_model_creator
@@ -191,13 +182,11 @@ shap_values = explainer([' '.join(list(map(lambda x: str(token_to_id[x]), person
shap_values.data = np.array([list(map(lambda x: f"{delphi_labels['name'].values[token_to_id[x[0]]]}({x[1]/365:.1f} years) ", person_to_process))])
# In[9]:
out = shap.plots.text(shap_values, display=True) # sometimes this interactive plot can't be rendered well (eg in VS Code, feel free to skip it)
# In[10]:
# SHAP values can be interpreted as how much each input token changes predicted logit corresponding to a particular disease.
@@ -221,7 +210,6 @@ with plt.style.context('default'):
#
# The small synthetic dataset is not enough to properly run following part; if you have access to the full dataset, run `shap-agg-eval.py` to evaluate SHAP values for the entire dataset.
# In[11]:
import pickle
@@ -233,7 +221,6 @@ all_tokens = shap_pkl['tokens']
all_values = shap_pkl['values']
# In[12]:
import pandas as pd
@@ -245,7 +232,6 @@ df_shap = pd.DataFrame(all_values)
df_shap['token'] = all_tokens.astype('int')
# In[13]:
token_count_dict = df_shap['token'].value_counts().sort_index().to_dict()
@@ -266,7 +252,6 @@ df_shap_agg = df_shap[df_shap['token'].apply(lambda x: token_count_dict[x] > N_m
#
# Let's see which diseases increase disease risk the most and also which diseases are most influenced by being a heavy smoker.
# In[14]:
import matplotlib.pyplot as plt
@@ -322,7 +307,6 @@ plot_shap_distribution(
)
# In[15]:
target_token = 9
@@ -362,7 +346,6 @@ plot_shap_distribution(
#
# Now, we will aggregate SHAP values within the pairs, additionally separating them by the time between the "predictor" and "predicted" diseases.
# In[16]:
d = get_batch(range(len(np.unique(shap_pkl['people']))), val, val_p2i,
@@ -370,7 +353,6 @@ d = get_batch(range(len(np.unique(shap_pkl['people']))), val, val_p2i,
device='cpu', padding='regular')
# In[17]:
has_gender = torch.isin(d[0], torch.tensor([2, 3])).any(dim=1).numpy()
@@ -378,7 +360,6 @@ is_male = torch.isin(d[0], torch.tensor([3])).any(dim=1).numpy()
is_female = torch.isin(d[0], torch.tensor([2])).any(dim=1).numpy()
# In[18]:
def get_person(idx):
@@ -394,7 +375,6 @@ def get_person(idx):
return person, y, time[0][-1]
# In[19]:
# the shap result pickle does not contain time, so we need to add it
@@ -419,7 +399,6 @@ for p in tqdm(np.unique(shap_pkl['people'])):
assert len(ages) == len(df_shap)
# In[20]:
all_tokens = shap_pkl['tokens']
@@ -441,7 +420,6 @@ df_shap = df_shap[df_shap['reg_time_years'] > 0]
token_count_dict = df_shap['token'].value_counts().sort_index().to_dict()
# In[22]:
import numpy as np
@@ -458,7 +436,6 @@ def bins_avg(x, y, grid_size=3):
return bin_edges, bin_avgs
# In[23]:
tokens_of_interest = [46, 95, 1168, 1188, 173, 214, 305, 505, 584]
@@ -508,7 +485,6 @@ for num_g, token_group in enumerate(np.array_split(tokens_of_interest, n_groups)
#
# Let's plot two separate heatmaps, one for the cases where the "predictor" disease occured in the past 5 years (with the "predicted disease being the reference) and one for the cases where it occured more than 10 years ago.
# In[24]:
N_min = 5
@@ -527,7 +503,6 @@ df_shap_agg_below_5y = df_shap[df_shap['token'].apply(lambda x: x in columns_mor
df_shap_agg_over_10y = df_shap[df_shap['token'].apply(lambda x: x in columns_more_N) & (df_shap['Time, years'] > 10)].groupby('token').mean()[columns_more_N]
# In[25]:
from matplotlib.colors import LogNorm
@@ -612,18 +587,15 @@ def plot_full_shap_heatmap(cur_df, title):
plt.show()
# In[26]:
plot_full_shap_heatmap(df_shap_agg_below_5y, 'Influence of tokens from below 5 years,\nrisk increase, folds')
# In[27]:
plot_full_shap_heatmap(df_shap_agg_over_10y, 'Influence of tokens from above 10 years,\nrisk increase, folds')
# Interestingly, the resulting heatmap has a block diagonal structure, meaning that within a chapter the interactions between diseases tend to be stronger than between chapters.
#
# The second ("above 10 years") heatmap is also more pale, meaning that most of the disease-disease interactions get weaker over time.