Refactor: Remove Jupyter Notebook cell markers
This commit is contained in:
@@ -25,7 +25,6 @@
|
||||
#
|
||||
#
|
||||
|
||||
# In[2]:
|
||||
|
||||
|
||||
import os
|
||||
@@ -63,7 +62,6 @@ dark_female = '#7A00BF'
|
||||
delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
|
||||
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death.
|
||||
@@ -75,7 +73,6 @@ delphi_labels.iloc[diseases_of_interest][['name', 'ICD-10 Chapter (short)']]
|
||||
|
||||
# ## Load model
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
out_dir = 'Delphi-2M'
|
||||
@@ -88,7 +85,6 @@ torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
|
||||
@@ -104,7 +100,6 @@ model = model.to(device)
|
||||
checkpoint['model_args']
|
||||
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
# Let's try to use the loaded model to extrapolate a partial health trajectory.
|
||||
@@ -133,7 +128,6 @@ example_health_trajectory = [
|
||||
example_health_trajectory = [(a, b * 365.25) for a,b in example_health_trajectory]
|
||||
|
||||
|
||||
# In[ ]:
|
||||
|
||||
|
||||
max_new_tokens = 100
|
||||
@@ -175,7 +169,6 @@ with torch.no_grad():
|
||||
#
|
||||
# No-event tokens eliminate long time intervals without tokens, which are typical for younger ages, when people generally have fewer diseases and therefore less medical records. Transformers predict the text token probability distribution only at the time of currently observed tokens, hence, no-event tokens can also be inserted during inference to obtain the predicted disease risk at any given time of interest.
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
from utils import get_batch, get_p2i
|
||||
@@ -194,7 +187,6 @@ dataset_subset_size = 2000 # len(val_p2i) # can be set to smaller number (e.g. 2
|
||||
|
||||
# ## Calibration of predicted times
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
# Fetch a bit of data and calculate future disease rates from it
|
||||
@@ -204,13 +196,11 @@ with torch.no_grad():
|
||||
t = (d[3]-d[1])[:,:].cpu().numpy().squeeze()
|
||||
|
||||
|
||||
# In[10]:
|
||||
|
||||
|
||||
from scipy.special import logsumexp
|
||||
|
||||
# Calculate expected waiting times from model predictions using competing exponentials theory
|
||||
# In Delphi's framework, each possible event has an exponential distribution with rate λᵢ = exp(logits[i])
|
||||
# The expected time until any event occurs is 1/sum(λᵢ) = 1/exp(logsumexp(logits))
|
||||
# logsumexp provides numerical stability vs. calculating exp(logits) directly
|
||||
|
||||
@@ -250,7 +240,6 @@ plt.gca().tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=0.0, grid_
|
||||
|
||||
# ## Incidence
|
||||
|
||||
# In[11]:
|
||||
|
||||
|
||||
## Load large chunk of data
|
||||
@@ -267,7 +256,6 @@ d = get_batch(range(subset_size), val, val_p2i,
|
||||
device=device, padding='random')
|
||||
|
||||
|
||||
# In[12]:
|
||||
|
||||
|
||||
# 2 is female token, 3 is male token
|
||||
@@ -277,7 +265,6 @@ is_female = (d[0] == 2).any(axis=1).cpu().numpy()
|
||||
has_gender = is_male | is_female
|
||||
|
||||
|
||||
# In[13]:
|
||||
|
||||
|
||||
# lets split the large data chanks to smaller batches and calculate the logits for the whole dataset
|
||||
@@ -295,7 +282,6 @@ d = [d_.cpu() for d_ in d]
|
||||
|
||||
# ### Age-sex incidence baseline
|
||||
|
||||
# In[14]:
|
||||
|
||||
|
||||
# calculate disease incidence rates for each disease, given age and sex
|
||||
@@ -334,7 +320,6 @@ females_in_ukb = np.cumsum(np.histogram((females[ukb_condition, 1]/365.25).astyp
|
||||
#
|
||||
# Bright dots are often located above the population average rates, which indicates that Delphi correctly captures the elevated disease risk for such participants.
|
||||
|
||||
# In[15]:
|
||||
|
||||
|
||||
def plot_age_incidence(ix, d, p, highlight_idx=0):
|
||||
@@ -444,7 +429,6 @@ def plot_age_incidence(ix, d, p, highlight_idx=0):
|
||||
axf[i].legend(loc='center left', bbox_to_anchor=(1.05, 0.5))
|
||||
|
||||
|
||||
# In[16]:
|
||||
|
||||
|
||||
plot_age_incidence(diseases_of_interest, d, p, highlight_idx=0)
|
||||
@@ -463,7 +447,6 @@ plt.show()
|
||||
# 5. Plot the calibration curve
|
||||
#
|
||||
|
||||
# In[17]:
|
||||
|
||||
|
||||
def auc(x1, x2):
|
||||
@@ -477,7 +460,6 @@ def auc(x1, x2):
|
||||
return U1 / n1 / n2
|
||||
|
||||
|
||||
# In[18]:
|
||||
|
||||
|
||||
d100k = get_batch(range(dataset_subset_size), val, val_p2i,
|
||||
@@ -485,7 +467,6 @@ d100k = get_batch(range(dataset_subset_size), val, val_p2i,
|
||||
device=device, padding='random')
|
||||
|
||||
|
||||
# In[19]:
|
||||
|
||||
|
||||
p100k = []
|
||||
@@ -497,7 +478,6 @@ with torch.no_grad():
|
||||
p100k = np.vstack(p100k)
|
||||
|
||||
|
||||
# In[20]:
|
||||
|
||||
|
||||
import scipy
|
||||
@@ -626,7 +606,6 @@ def plot_calibration(disease_idx, data, logits, offset = 365.25, age_groups=rang
|
||||
return out
|
||||
|
||||
|
||||
# In[21]:
|
||||
|
||||
|
||||
out = []
|
||||
@@ -649,7 +628,6 @@ for j, k in enumerate(diseases_of_interest):
|
||||
plt.show()
|
||||
|
||||
|
||||
# In[22]:
|
||||
|
||||
|
||||
# the same calibrations curves as above, but a more compact version
|
||||
@@ -693,7 +671,6 @@ plt.show()
|
||||
# 4. Calculate the AUC using Delphi disease rates as predictors
|
||||
# 5. (Optional) Use DeLong's method (recommended) or bootstrap to calculate the variance of the AUC
|
||||
|
||||
# In[23]:
|
||||
|
||||
|
||||
from evaluate_auc import get_calibration_auc, evaluate_auc_pipeline
|
||||
@@ -719,7 +696,6 @@ auc_inputs = {
|
||||
}
|
||||
|
||||
|
||||
# In[24]:
|
||||
|
||||
|
||||
all_aucs = []
|
||||
@@ -745,7 +721,6 @@ for disease_idx_batch, disease_idx in tqdm(enumerate(diseases_of_interest), tota
|
||||
all_aucs.append(out_item)
|
||||
|
||||
|
||||
# In[25]:
|
||||
|
||||
|
||||
# this df contains AUC calculations for all diseases, sexes, and age groups
|
||||
@@ -779,7 +754,6 @@ auc_df = auc_df.merge(delphi_labels[['name', 'index']], left_on='token', right_o
|
||||
auc_df
|
||||
|
||||
|
||||
# In[26]:
|
||||
|
||||
|
||||
plt.figure(figsize=(7, 5))
|
||||
@@ -817,7 +791,6 @@ plt.tight_layout()
|
||||
#
|
||||
# Therefore, we well use precomputed results here.
|
||||
|
||||
# In[27]:
|
||||
|
||||
|
||||
# df_auc_unpooled_merged, df_auc_merged = evaluate_auc_pipeline(model,
|
||||
@@ -830,14 +803,12 @@ plt.tight_layout()
|
||||
# )
|
||||
|
||||
|
||||
# In[28]:
|
||||
|
||||
|
||||
df_auc_all_diseases = pd.read_csv('supplementary/delphi_auc.csv')
|
||||
df_auc_all_diseases['mean_auc'] = df_auc_all_diseases[['AUC Female, (no gap)', 'AUC Male, (no gap)']].mean(axis=1)
|
||||
|
||||
|
||||
# In[29]:
|
||||
|
||||
|
||||
plt.figure(figsize=(7, 4))
|
||||
@@ -852,7 +823,6 @@ plt.ylabel('AUC')
|
||||
plt.show()
|
||||
|
||||
|
||||
# In[30]:
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -909,7 +879,6 @@ plt.grid(axis='x', visible=False)
|
||||
plt.show()
|
||||
|
||||
|
||||
# In[31]:
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -984,7 +953,6 @@ plt.show()
|
||||
#
|
||||
# Attention maps can be used for interpretability, however for a more robust interpretation, we suggest using SHAP values (`shap_analysis.ipynb`).
|
||||
|
||||
# In[32]:
|
||||
|
||||
|
||||
d = get_batch([0], val, val_p2i, select='left', block_size=model.config.block_size, device=device)
|
||||
@@ -1002,7 +970,6 @@ for i in range(att.shape[0]):
|
||||
plt.tight_layout()
|
||||
|
||||
|
||||
# In[33]:
|
||||
|
||||
|
||||
d = get_batch(range(dataset_subset_size), val, val_p2i,
|
||||
@@ -1018,7 +985,6 @@ att.shape
|
||||
#
|
||||
# Generally, tokens tend to lose most of their importance pretty quickly. High attention for the most recent token in the trajectory is likely due to this tokens being used by the model to estimate the current age of the patient, which is a very important predictor for the overall disease risk.
|
||||
|
||||
# In[34]:
|
||||
|
||||
|
||||
import textwrap
|
||||
@@ -1045,14 +1011,12 @@ for i in range(len(w[0])):
|
||||
#
|
||||
# We see that diseases cluster by their ICD-10 chapter - which is interesting, because the model had no knowledge about the ICD-10 hierarchy during training; all diseases were treated equally.
|
||||
|
||||
# In[35]:
|
||||
|
||||
|
||||
import umap
|
||||
import matplotlib as mpl
|
||||
|
||||
|
||||
# In[36]:
|
||||
|
||||
|
||||
wte = model.transformer.wte.weight.cpu().detach().numpy()
|
||||
@@ -1064,7 +1028,6 @@ u = u0 - np.median(u0, axis=0)
|
||||
u = - u
|
||||
|
||||
|
||||
# In[37]:
|
||||
|
||||
|
||||
def remove_ticks(ax):
|
||||
@@ -1080,7 +1043,6 @@ def remove_ticks(ax):
|
||||
tick.tick2line.set_visible(False)
|
||||
|
||||
|
||||
# In[38]:
|
||||
|
||||
|
||||
labels_all = pd.read_csv('delphi_labels_chapters_colours_icd.csv')
|
||||
@@ -1094,7 +1056,6 @@ short_names_present = [i for i in short_names if i in labels_non_technical['ICD-
|
||||
color_mapping_short = {k: v for k, v in labels_all[['ICD-10 Chapter (short)', 'color']].values}
|
||||
|
||||
|
||||
# In[39]:
|
||||
|
||||
|
||||
import seaborn as sns
|
||||
|
Reference in New Issue
Block a user