{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "08bd379a", "metadata": {}, "outputs": [], "source": [
 "import os\n",
 "import torch\n",
 "from models import TimeAwareGPT2\n",
 "from utils import load_model\n",
 "from tqdm import tqdm\n",
 "import pandas as pd\n",
 "import numpy as np\n",
 "import textwrap\n",
 "import matplotlib.pyplot as plt\n",
 "\n",
 "# Global matplotlib style: white background, dotted grid, no axis spines.\n",
 "plt.rcParams.update({\n",
 "    'figure.facecolor': 'white',\n",
 "    'axes.grid': True,\n",
 "    'grid.linestyle': ':',\n",
 "    'axes.spines.bottom': False,\n",
 "    'axes.spines.left': False,\n",
 "    'axes.spines.right': False,\n",
 "    'axes.spines.top': False,\n",
 "    'figure.dpi': 72,\n",
 "    'pdf.fonttype': 42,  # 42 = embed TrueType (Type 42) fonts in PDF output\n",
 "})\n",
 "\n",
 "# Colour palette (light, normal, dark): green shades for male, purple shades for female.\n",
 "light_male, normal_male, dark_male = '#BAEBE3', '#0FB8A1', '#00574A'\n",
 "light_female, normal_female, dark_female = '#DEC7FF', '#8520F1', '#7A00BF'\n",
 "\n",
 "# Label table for Delphi tokens; later cells look rows up by token id and read\n",
 "# the 'name', 'index' and 'ICD-10 Chapter (short)' columns.\n",
 "delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')" ] }, { "cell_type": "code", "execution_count": 2, "id": "1d8375ab", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
nameICD-10 Chapter (short)
46A41 Other septicaemiaI. Infectious Diseases
95B01 Varicella [chickenpox]I. Infectious Diseases
1168C25 Malignant neoplasm of pancreasII. Neoplasms
1188C50 Malignant neoplasm of breastII. Neoplasms
374G30 Alzheimer's diseaseVI. Nervous System Diseases
214E10 Insulin-dependent diabetes mellitusIV. Metabolic Diseases
305F32 Depressive episodeV. Mental Disorders
505I21 Acute myocardial infarctionIX. Circulatory Diseases
603J45 AsthmaX. Respiratory Diseases
1269DeathDeath
\n", "
" ], "text/plain": [ " name ICD-10 Chapter (short)\n", "46 A41 Other septicaemia I. Infectious Diseases\n", "95 B01 Varicella [chickenpox] I. Infectious Diseases\n", "1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n", "1188 C50 Malignant neoplasm of breast II. Neoplasms\n", "374 G30 Alzheimer's disease VI. Nervous System Diseases\n", "214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n", "305 F32 Depressive episode V. Mental Disorders\n", "505 I21 Acute myocardial infarction IX. Circulatory Diseases\n", "603 J45 Asthma X. Respiratory Diseases\n", "1269 Death Death" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death.\n",
 "# For illustrative purposes, some of the plots will focus on a subset of 10 selected diseases - the same subset as used in the Delphi paper.\n",
 "# The numbers are row positions in delphi_labels (1269 is the Death row).\n",
 "\n",
 "diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n",
 "delphi_labels.iloc[diseases_of_interest][['name', 'ICD-10 Chapter (short)']]" ] }, { "cell_type": "code", "execution_count": 3, "id": "f8d86352", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cpu\n", "Model config: {'n_layer': 12, 'n_embd': 120, 'n_head': 12, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n", "Model loaded from best_model_n_embd_120_n_layer_12_n_head_12.pt with 2.40M parameters.\n" ] } ], "source": [
 "seed = 1337\n",
 "\n",
 "# Seed torch's CPU and CUDA random number generators for reproducible results.\n",
 "torch.manual_seed(seed)\n",
 "torch.cuda.manual_seed(seed)\n",
 "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
 "print(f'Using device: {device}')\n",
 "\n",
 "# The config and checkpoint files share one hyper-parameter tag; keep them in sync.\n",
 "model_tag = 'n_embd_120_n_layer_12_n_head_12'\n",
 "# NOTE(review): 1270 is presumably the token vocabulary size (ids 0..1269 in delphi_labels) - confirm against utils.load_model.\n",
 "model = load_model(f'config_{model_tag}.json',\n",
 "                   f'best_model_{model_tag}.pt',\n",
 "                   1270)\n",
 "model.eval()\n",
 "model = model.to(device)" ] }, { "cell_type": "code", "execution_count": 4, "id": "a8658fb6", "metadata": {}, "outputs": [], "source": [
 "# Let's try to use the loaded model to extrapolate a partial health trajectory.\n",
 "# Each entry is (token name, age in years); names must match delphi_labels['name'].\n",
 "\n",
 "example_health_trajectory = [\n",
 "    ('Male', 0),\n",
 "    ('B01 Varicella [chickenpox]', 2),\n",
 "    ('L20 Atopic dermatitis', 3),\n",
 "    ('No event', 5),\n",
 "    ('No event', 10),\n",
 "    ('No event', 15),\n",
 "    ('No event', 20),\n",
 "    ('G43 Migraine', 20),\n",
 "    ('E73 Lactose intolerance', 21),\n",
 "    ('B27 Infectious mononucleosis', 22),\n",
 "    ('No event', 25),\n",
 "    ('J11 Influenza, virus not identified', 28),\n",
 "    ('No event', 30),\n",
 "    ('No event', 35),\n",
 "    ('No event', 40),\n",
 "    ('Smoking low', 41),\n",
 "    ('BMI mid', 41),\n",
 "    ('Alcohol low', 41),\n",
 "    ('No event', 42),\n",
 "]\n",
 "# Ages are passed to the model in days: convert years -> days with 365.25 days per year.\n",
 "example_health_trajectory = [(name, age_years * 365.25) for name, age_years in example_health_trajectory]" ] }, { "cell_type": "code", "execution_count": 8, "id": "d7f2e4a1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input trajectory:\n", "0.0: Male\n", "2.0: B01 Varicella [chickenpox]\n", "3.0: L20 Atopic dermatitis\n", "5.0: No event\n", "10.0: No event\n", "15.0: No event\n", "20.0: No event\n", "20.0: G43 Migraine\n", "21.0: E73 Lactose intolerance\n", "22.0: B27 Infectious mononucleosis\n", "25.0: No event\n", "28.0: J11 Influenza, virus not identified\n", "30.0: No event\n", "35.0: No event\n", "40.0: No event\n", "41.0: Smoking low\n", "41.0: BMI mid\n", "41.0: Alcohol low\n", "42.0: No event\n", "=====================\n", "Generated trajectory:\n", "44.9: E03 Other hypothyroidism\n", "45.2: M19 Other arthrosis\n", "45.7: H60 Otitis externa\n", "47.1: H10 Conjunctivitis\n", "47.6: H66 Suppurative and unspecified otitis media\n", "48.4: M25 Other joint disorders, not elsewhere classified\n", "49.8: M79 Other soft tissue disorders, not elsewhere classified\n", "50.9: M65 Synovitis and tenosynovitis\n", "51.2: D50 Iron
deficiency anaemia\n", "51.9: K57 Diverticular disease of intestine\n", "52.6: K63 Other diseases of intestine\n", "52.8: K66 Other disorders of peritoneum\n", "52.9: A09 Diarrhoea and gastro-enteritis of presumed infectious origin\n", "53.8: K20 Oesophagitis\n", "53.8: K29 Gastritis and duodenitis\n", "53.9: L30 Other dermatitis\n", "54.5: M54 Dorsalgia\n", "54.8: M15 Polyarthrosis\n", "56.1: N30 Cystitis\n", "56.3: F07 Personality and behavioural disorders due to brain disease, damage and dysfunction\n", "56.4: K80 Cholelithiasis\n", "56.4: M13 Other arthritis\n", "56.6: B37 Candidiasis\n", "58.0: E11 Non-insulin-dependent diabetes mellitus\n", "58.0: K59 Other functional intestinal disorders\n", "60.5: L03 Cellulitis\n", "60.8: K52 Other non-infective gastro-enteritis and colitis\n", "61.4: K31 Other diseases of stomach and duodenum\n", "61.8: K74 Fibrosis and cirrhosis of liver\n", "61.8: Death\n" ] } ], "source": [
 "# Extrapolate the example trajectory with Delphi and print the input and generated events.\n",
 "max_new_tokens = 100\n",
 "days_per_year = 365.25  # same factor used to convert example_health_trajectory from years to days\n",
 "\n",
 "# Token ids come from the 'index' column, while the reverse lookup below uses row labels,\n",
 "# so the two must agree for printed names to match the model's tokens.\n",
 "assert (delphi_labels['index'] == delphi_labels.index).all(), 'delphi_labels index column must match its row labels'\n",
 "name_to_token_id = dict(zip(delphi_labels['name'], delphi_labels['index']))\n",
 "death_token = name_to_token_id['Death']\n",
 "\n",
 "events = torch.tensor([name_to_token_id[name] for name, _ in example_health_trajectory],\n",
 "                      device=device).unsqueeze(0)\n",
 "ages = torch.tensor([age_days for _, age_days in example_health_trajectory],\n",
 "                    device=device).unsqueeze(0)\n",
 "\n",
 "with torch.no_grad():\n",
 "    tokens_out, ages_out, _ = model.generate(events, ages, max_new_tokens, termination_tokens=[death_token])\n",
 "\n",
 "# Convert model outputs to readable format: token ids -> names, ages in days -> years.\n",
 "token_ids = tokens_out.cpu().numpy().flatten()\n",
 "ages_years = ages_out.cpu().numpy().flatten() / days_per_year\n",
 "\n",
 "print('Input trajectory:')\n",
 "for i, (event_id, age_years) in enumerate(zip(token_ids, ages_years)):\n",
 "    # The generated events start right after the input events.\n",
 "    if i == len(example_health_trajectory):\n",
 "        print('=====================')\n",
 "        print('Generated trajectory:')\n",
 "    event_name = delphi_labels.loc[event_id, 'name']\n",
 "    print(f'{age_years:2.1f}: {event_name}')" ] }, { "cell_type": "code", "execution_count": null, "id": "0b937d75",
"metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }