430 lines
42 KiB
Plaintext
430 lines
42 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "08bd379a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Imports: standard library, third-party, then project-local modules.\n",
"import textwrap\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"from models import TimeAwareGPT2\n",
"from utils import load_model\n",
"\n",
"# Plot styling: white background, dotted grid, no axis spines,\n",
"# figure dpi of 72 and TrueType (type 42) fonts in PDF exports.\n",
"plt.rcParams.update({\n",
"    'figure.facecolor': 'white',\n",
"    'axes.grid': True,\n",
"    'grid.linestyle': ':',\n",
"    'axes.spines.bottom': False,\n",
"    'axes.spines.left': False,\n",
"    'axes.spines.right': False,\n",
"    'axes.spines.top': False,\n",
"    'figure.dpi': 72,\n",
"    'pdf.fonttype': 42,\n",
"})\n",
"\n",
"# Colour shades (light, normal, dark): green for male, purple for female.\n",
"light_male, normal_male, dark_male = '#BAEBE3', '#0FB8A1', '#00574A'\n",
"light_female, normal_female, dark_female = '#DEC7FF', '#8520F1', '#7A00BF'\n",
"\n",
"# Token label table; later cells read its 'name', 'index' and\n",
"# 'ICD-10 Chapter (short)' columns, and look rows up by token id.\n",
"delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "1d8375ab",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>name</th>\n",
|
|
" <th>ICD-10 Chapter (short)</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>46</th>\n",
|
|
" <td>A41 Other septicaemia</td>\n",
|
|
" <td>I. Infectious Diseases</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>95</th>\n",
|
|
" <td>B01 Varicella [chickenpox]</td>\n",
|
|
" <td>I. Infectious Diseases</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1168</th>\n",
|
|
" <td>C25 Malignant neoplasm of pancreas</td>\n",
|
|
" <td>II. Neoplasms</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1188</th>\n",
|
|
" <td>C50 Malignant neoplasm of breast</td>\n",
|
|
" <td>II. Neoplasms</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>374</th>\n",
|
|
" <td>G30 Alzheimer's disease</td>\n",
|
|
" <td>VI. Nervous System Diseases</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>214</th>\n",
|
|
" <td>E10 Insulin-dependent diabetes mellitus</td>\n",
|
|
" <td>IV. Metabolic Diseases</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>305</th>\n",
|
|
" <td>F32 Depressive episode</td>\n",
|
|
" <td>V. Mental Disorders</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>505</th>\n",
|
|
" <td>I21 Acute myocardial infarction</td>\n",
|
|
" <td>IX. Circulatory Diseases</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>603</th>\n",
|
|
" <td>J45 Asthma</td>\n",
|
|
" <td>X. Respiratory Diseases</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1269</th>\n",
|
|
" <td>Death</td>\n",
|
|
" <td>Death</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" name ICD-10 Chapter (short)\n",
|
|
"46 A41 Other septicaemia I. Infectious Diseases\n",
|
|
"95 B01 Varicella [chickenpox] I. Infectious Diseases\n",
|
|
"1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n",
|
|
"1188 C50 Malignant neoplasm of breast II. Neoplasms\n",
|
|
"374 G30 Alzheimer's disease VI. Nervous System Diseases\n",
|
|
"214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n",
|
|
"305 F32 Depressive episode V. Mental Disorders\n",
|
|
"505 I21 Acute myocardial infarction IX. Circulatory Diseases\n",
|
|
"603 J45 Asthma X. Respiratory Diseases\n",
|
|
"1269 Death Death"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Delphi predicts disease risk for 1,256 ICD-10 diseases plus death.\n",
"# For illustration, some of the plots focus on a subset of 10 selected diseases -- the same subset used in the Delphi paper.\n",
"\n",
"diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n",
"summary_columns = ['name', 'ICD-10 Chapter (short)']\n",
"delphi_labels[summary_columns].iloc[diseases_of_interest]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "f8d86352",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using device: cpu\n",
|
|
"Model config: {'n_layer': 16, 'n_embd': 256, 'n_head': 16, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n",
|
|
"Model loaded from best_model_n_embd_256_n_layer_16_n_head_16.pt with 13.29M parameters.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fix the RNG state so the generation below is reproducible.\n",
"seed = 1337\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"\n",
"# Prefer a GPU when one is available.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"print(f'Using device: {device}')\n",
"\n",
"# Checkpoint for the n_layer=16, n_embd=256, n_head=16 configuration.\n",
"# NOTE(review): the third load_model argument (1270) is presumably the token vocabulary size -- confirm in utils.load_model.\n",
"model = load_model(\n",
"    'config_n_embd_256_n_layer_16_n_head_16.json',\n",
"    'best_model_n_embd_256_n_layer_16_n_head_16.pt',\n",
"    1270,\n",
")\n",
"model.eval()\n",
"model = model.to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "a8658fb6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Extrapolate a partial health trajectory with the loaded model.\n",
"# Each entry is (token name, age in years); names must match delphi_labels['name'].\n",
"\n",
"trajectory_in_years = [\n",
"    ('Male', 0),\n",
"    ('B01 Varicella [chickenpox]', 2),\n",
"    ('L20 Atopic dermatitis', 3),\n",
"    ('No event', 5),\n",
"    ('No event', 10),\n",
"    ('No event', 15),\n",
"    ('No event', 20),\n",
"    ('G43 Migraine', 20),\n",
"    ('E73 Lactose intolerance', 21),\n",
"    ('B27 Infectious mononucleosis', 22),\n",
"    ('No event', 25),\n",
"    ('J11 Influenza, virus not identified', 28),\n",
"    ('No event', 30),\n",
"    ('No event', 35),\n",
"    ('No event', 40),\n",
"    ('Smoking low', 41),\n",
"    ('BMI mid', 41),\n",
"    ('Alcohol low', 41),\n",
"    ('No event', 42),\n",
"]\n",
"\n",
"# Ages are fed to the model in days (365.25 days per year).\n",
"example_health_trajectory = [(event_name, age_years * 365.25) for event_name, age_years in trajectory_in_years]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "d7f2e4a1",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Input trajectory:\n",
|
|
"0.0: Male\n",
|
|
"2.0: B01 Varicella [chickenpox]\n",
|
|
"3.0: L20 Atopic dermatitis\n",
|
|
"5.0: No event\n",
|
|
"10.0: No event\n",
|
|
"15.0: No event\n",
|
|
"20.0: No event\n",
|
|
"20.0: G43 Migraine\n",
|
|
"21.0: E73 Lactose intolerance\n",
|
|
"22.0: B27 Infectious mononucleosis\n",
|
|
"25.0: No event\n",
|
|
"28.0: J11 Influenza, virus not identified\n",
|
|
"30.0: No event\n",
|
|
"35.0: No event\n",
|
|
"40.0: No event\n",
|
|
"41.0: Smoking low\n",
|
|
"41.0: BMI mid\n",
|
|
"41.0: Alcohol low\n",
|
|
"42.0: No event\n",
|
|
"=====================\n",
|
|
"Generated trajectory:\n",
|
|
"43.9: B95 Streptococcus and staphylococcus as the cause of diseases classified to other chapters\n",
|
|
"44.6: J30 Vasomotor and allergic rhinitis\n",
|
|
"44.8: L03 Cellulitis\n",
|
|
"46.6: L30 Other dermatitis\n",
|
|
"46.8: K56 Paralytic ileus and intestinal obstruction without hernia\n",
|
|
"47.8: K76 Other diseases of liver\n",
|
|
"48.8: N20 Calculus of kidney and ureter\n",
|
|
"48.8: L24 Irritant contact dermatitis\n",
|
|
"49.5: I10 Essential primary hypertension\n",
|
|
"49.5: K59 Other functional intestinal disorders\n",
|
|
"50.6: B96 Other bacterial agents as the cause of diseases classified to other chapters\n",
|
|
"51.4: E14 Unspecified diabetes mellitus\n",
|
|
"51.6: E55 Vitamin d deficiency\n",
|
|
"51.7: A41 Other septicaemia\n",
|
|
"52.0: F41 Other anxiety disorders\n",
|
|
"52.1: F32 Depressive episode\n",
|
|
"52.3: M19 Other arthrosis\n",
|
|
"53.5: G96 Other disorders of central nervous system\n",
|
|
"53.7: E87 Other disorders of fluid, electrolyte and acid-base balance\n",
|
|
"53.8: J18 Pneumonia, organism unspecified\n",
|
|
"53.8: E66 Obesity\n",
|
|
"53.8: E11 Non-insulin-dependent diabetes mellitus\n",
|
|
"54.4: Death\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Autoregressively extend the example trajectory until Death or max_new_tokens.\n",
"max_new_tokens = 100\n",
"death_token_id = 1269  # token 1269 is 'Death' in the label table above\n",
"days_per_year = 365.25  # must match the conversion used to build example_health_trajectory\n",
"\n",
"# Map label names to token ids via the table's 'index' column.\n",
"name_to_token_id = dict(zip(delphi_labels['name'], delphi_labels['index']))\n",
"\n",
"events = torch.tensor([name_to_token_id[name] for name, _ in example_health_trajectory], device=device).unsqueeze(0)\n",
"ages = torch.tensor([age for _, age in example_health_trajectory], device=device).unsqueeze(0)\n",
"\n",
"with torch.no_grad():\n",
"    y, b, _ = model.generate(events, ages, max_new_tokens, termination_tokens=[death_token_id])\n",
"\n",
"# y / b contain the input prefix followed by the generated tokens and their ages (days).\n",
"# Convert ages back to years with the same 365.25 factor used on the way in.\n",
"event_ids = y.cpu().numpy().flatten()\n",
"ages_years = b.cpu().numpy().flatten() / days_per_year\n",
"n_input = len(example_health_trajectory)\n",
"\n",
"print('Input trajectory:')\n",
"for i, (event_id, age_years) in enumerate(zip(event_ids, ages_years)):\n",
"    if i == n_input:\n",
"        print('=====================')\n",
"        print('Generated trajectory:')\n",
"    event_name = delphi_labels.loc[event_id, 'name']\n",
"    print(f'{age_years:2.1f}: {event_name}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "0b937d75",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from utils import PatientEventDataset, get_batch\n",
"\n",
"train_data_path = 'ukb_real_train.bin'\n",
"val_data_path = 'ukb_real_val.bin'\n",
"\n",
"def _open_event_array(path):\n",
"    # Read-only memory map of a flat uint32 file, viewed as rows of 3 fields\n",
"    # (field layout presumably defined by utils.PatientEventDataset -- confirm).\n",
"    return np.memmap(path, dtype=np.uint32, mode='r').reshape(-1, 3)\n",
"\n",
"train_data_arr = _open_event_array(train_data_path)\n",
"val_data_arr = _open_event_array(val_data_path)\n",
"\n",
"# NOTE(review): block_length is presumably the per-sample context length -- confirm in PatientEventDataset.\n",
"block_length = 128\n",
"train_dataset = PatientEventDataset(train_data_arr, block_length)\n",
"val_dataset = PatientEventDataset(val_data_arr, block_length)\n",
"\n",
"# Not used further in this notebook.\n",
"dataset_subset_size = 2048"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "6a3bb2dc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Score the first 256 validation samples.\n",
"input_events, input_times, target_events, target_times = get_batch(val_dataset, slice(256))\n",
"\n",
"with torch.no_grad():\n",
"    p = model(input_events.to(device), input_times.to(device)).cpu().detach().numpy().squeeze()\n",
"\n",
"# Observed waiting time from each input token to its target token.\n",
"t = (target_times - input_times).cpu().numpy().squeeze()\n",
"target_events = target_events.cpu().numpy().squeeze()\n",
"\n",
"# Drop positions whose target is one of these ids\n",
"# (presumably padding and non-disease tokens -- confirm against the label table).\n",
"ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]\n",
"mask = np.isin(target_events, ignored_token_ids, invert=True)\n",
"p, t = p[mask], t[mask]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "1c32bd6e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 288x288 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"from scipy.special import logsumexp\n",
"\n",
"# Calibration of predicted waiting times.\n",
"# Delphi treats every possible next event as a competing exponential with rate exp(logit_i),\n",
"# so the time to the next event of any kind is exponential with rate sum_i exp(logit_i) and mean\n",
"# 1 / sum_i exp(logit_i) = exp(-logsumexp(logits)); logsumexp avoids overflow in exp(logits).\n",
"expected_t = np.exp(-logsumexp(p, axis=-1))\n",
"\n",
"# Logarithmic bins of width delta_log_t decades, starting at 10**1.75 days.\n",
"delta_log_t = 0.1\n",
"log_range = np.arange(1.75, 4, delta_log_t)\n",
"\n",
"# Mean observed waiting time per bin; t == 0 is excluded and empty bins give NaN.\n",
"observed_t = []\n",
"for i in log_range:\n",
"    bin_mask = (expected_t > 10**i) & (expected_t <= 10**(i + delta_log_t)) & (t > 0)\n",
"    observed_t.append(t[bin_mask].mean() if bin_mask.any() else np.nan)\n",
"\n",
"fig, ax = plt.subplots(figsize=(4, 4))\n",
"ax.set_aspect('equal')\n",
"# NOTE(review): +0.5 offsets waiting times for the log axis, but with ylim starting at 1\n",
"# the t == 0 points (plotted at 0.5) are still clipped -- confirm intent.\n",
"ax.scatter(expected_t, t + 0.5, marker='.', c='lightgrey', rasterized=True)\n",
"# Plot each bin mean at its log-space centre; reuse log_range so x stays aligned with observed_t.\n",
"ax.plot(10**(log_range + delta_log_t / 2.), observed_t, label='average')\n",
"ax.set_xscale('log')\n",
"ax.set_yscale('log')\n",
"ax.set_xlim(1, 2e3)\n",
"ax.set_ylim(1, 2e3)\n",
"ax.set_xlabel('Expected days to next token')\n",
"ax.set_ylabel('Observed days to next token')\n",
"ax.set_title('Predicted vs observed waiting times')\n",
"ax.legend()\n",
"# Identity line y = x: the axes-fraction diagonal, since both log axes share the same limits.\n",
"ax.plot([0, 1], [0, 1], transform=ax.transAxes, c='k', ls=(0, (5, 5)), linewidth=0.7)\n",
"\n",
"ax.tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=1, grid_linewidth=0.45, grid_linestyle=':')\n",
"ax.tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=0.0, grid_linewidth=0.35, which='minor')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4e5e3c8f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "base",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|