diff --git a/evaluate_models.ipynb b/evaluate_models.ipynb
new file mode 100644
index 0000000..2a3845a
--- /dev/null
+++ b/evaluate_models.ipynb
@@ -0,0 +1,341 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "08bd379a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "from models import TimeAwareGPT2\n",
+ "from utils import load_model\n",
+ "from tqdm import tqdm\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import textwrap\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Global matplotlib styling, applied in a single rcParams update.\n",
+ "plt.rcParams.update({\n",
+ "    'figure.facecolor': 'white',\n",
+ "    'figure.dpi': 72,\n",
+ "    # Type 42 embeds fonts as TrueType in PDF output.\n",
+ "    'pdf.fonttype': 42,\n",
+ "    # Dotted grid, no axis spines.\n",
+ "    'axes.grid': True,\n",
+ "    'grid.linestyle': ':',\n",
+ "    'axes.spines.bottom': False,\n",
+ "    'axes.spines.left': False,\n",
+ "    'axes.spines.right': False,\n",
+ "    'axes.spines.top': False,\n",
+ "})\n",
+ "\n",
+ "# Colour palette (light / normal / dark): green shades for male ...\n",
+ "light_male, normal_male, dark_male = '#BAEBE3', '#0FB8A1', '#00574A'\n",
+ "# ... and purple shades for female.\n",
+ "light_female, normal_female, dark_female = '#DEC7FF', '#8520F1', '#7A00BF'\n",
+ "\n",
+ "# Label table; later cells read its 'name', 'index' and 'ICD-10 Chapter (short)' columns.\n",
+ "delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "1d8375ab",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
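+ "<div>\n",
+ "<style scoped>\n",
+ "    .dataframe tbody tr th:only-of-type {\n",
+ "        vertical-align: middle;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe tbody tr th {\n",
+ "        vertical-align: top;\n",
+ "    }\n",
+ "\n",
+ "    .dataframe thead th {\n",
+ "        text-align: right;\n",
+ "    }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>name</th>\n",
+ "      <th>ICD-10 Chapter (short)</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>46</th>\n",
+ "      <td>A41 Other septicaemia</td>\n",
+ "      <td>I. Infectious Diseases</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>95</th>\n",
+ "      <td>B01 Varicella [chickenpox]</td>\n",
+ "      <td>I. Infectious Diseases</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1168</th>\n",
+ "      <td>C25 Malignant neoplasm of pancreas</td>\n",
+ "      <td>II. Neoplasms</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1188</th>\n",
+ "      <td>C50 Malignant neoplasm of breast</td>\n",
+ "      <td>II. Neoplasms</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>374</th>\n",
+ "      <td>G30 Alzheimer's disease</td>\n",
+ "      <td>VI. Nervous System Diseases</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>214</th>\n",
+ "      <td>E10 Insulin-dependent diabetes mellitus</td>\n",
+ "      <td>IV. Metabolic Diseases</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>305</th>\n",
+ "      <td>F32 Depressive episode</td>\n",
+ "      <td>V. Mental Disorders</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>505</th>\n",
+ "      <td>I21 Acute myocardial infarction</td>\n",
+ "      <td>IX. Circulatory Diseases</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>603</th>\n",
+ "      <td>J45 Asthma</td>\n",
+ "      <td>X. Respiratory Diseases</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>1269</th>\n",
+ "      <td>Death</td>\n",
+ "      <td>Death</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"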
+ ],
+ "text/plain": [
+ " name ICD-10 Chapter (short)\n",
+ "46 A41 Other septicaemia I. Infectious Diseases\n",
+ "95 B01 Varicella [chickenpox] I. Infectious Diseases\n",
+ "1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n",
+ "1188 C50 Malignant neoplasm of breast II. Neoplasms\n",
+ "374 G30 Alzheimer's disease VI. Nervous System Diseases\n",
+ "214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n",
+ "305 F32 Depressive episode V. Mental Disorders\n",
+ "505 I21 Acute myocardial infarction IX. Circulatory Diseases\n",
+ "603 J45 Asthma X. Respiratory Diseases\n",
+ "1269 Death Death"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Delphi can predict the risk of 1,256 ICD-10 diseases plus death.\n",
+ "# For illustration, some plots focus on a subset of 10 selected diseases - the same subset used in the Delphi paper.\n",
+ "\n",
+ "diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n",
+ "\n",
+ "# NOTE(review): rows are selected by position, which assumes row position == token id - confirm against the 'index' column.\n",
+ "label_columns = ['name', 'ICD-10 Chapter (short)']\n",
+ "delphi_labels[label_columns].iloc[diseases_of_interest]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f8d86352",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using device: cpu\n",
+ "Model config: {'n_layer': 12, 'n_embd': 120, 'n_head': 12, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n",
+ "Model loaded from best_model_n_embd_120_n_layer_12_n_head_12.pt with 2.40M parameters.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Seed PyTorch's CPU and CUDA random number generators.\n",
+ "seed = 1337\n",
+ "torch.manual_seed(seed)\n",
+ "torch.cuda.manual_seed(seed)\n",
+ "\n",
+ "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "print(f'Using device: {device}')\n",
+ "\n",
+ "# Arguments: config file, checkpoint file, vocabulary size (load_model's vocab_size).\n",
+ "model = load_model(\n",
+ "    'config_n_embd_120_n_layer_12_n_head_12.json',\n",
+ "    'best_model_n_embd_120_n_layer_12_n_head_12.pt',\n",
+ "    1270,\n",
+ ")\n",
+ "# Put the model in evaluation mode, then move it to the selected device.\n",
+ "model = model.eval().to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "a8658fb6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Use the loaded model to extrapolate a partial health trajectory.\n",
+ "# Each entry is (event name, age in years).\n",
+ "trajectory_in_years = [\n",
+ "    ('Male', 0),\n",
+ "    ('B01 Varicella [chickenpox]', 2),\n",
+ "    ('L20 Atopic dermatitis', 3),\n",
+ "    ('No event', 5),\n",
+ "    ('No event', 10),\n",
+ "    ('No event', 15),\n",
+ "    ('No event', 20),\n",
+ "    ('G43 Migraine', 20),\n",
+ "    ('E73 Lactose intolerance', 21),\n",
+ "    ('B27 Infectious mononucleosis', 22),\n",
+ "    ('No event', 25),\n",
+ "    ('J11 Influenza, virus not identified', 28),\n",
+ "    ('No event', 30),\n",
+ "    ('No event', 35),\n",
+ "    ('No event', 40),\n",
+ "    ('Smoking low', 41),\n",
+ "    ('BMI mid', 41),\n",
+ "    ('Alcohol low', 41),\n",
+ "    ('No event', 42),\n",
+ "]\n",
+ "\n",
+ "# Scale every age by 365.25 (years -> days) while keeping the event names unchanged.\n",
+ "example_health_trajectory = [\n",
+ "    (event_name, age_years * 365.25)\n",
+ "    for event_name, age_years in trajectory_in_years\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "d7f2e4a1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input trajectory:\n",
+ "0.0: Male\n",
+ "2.0: B01 Varicella [chickenpox]\n",
+ "3.0: L20 Atopic dermatitis\n",
+ "5.0: No event\n",
+ "10.0: No event\n",
+ "15.0: No event\n",
+ "20.0: No event\n",
+ "20.0: G43 Migraine\n",
+ "21.0: E73 Lactose intolerance\n",
+ "22.0: B27 Infectious mononucleosis\n",
+ "25.0: No event\n",
+ "28.0: J11 Influenza, virus not identified\n",
+ "30.0: No event\n",
+ "35.0: No event\n",
+ "40.0: No event\n",
+ "41.0: Smoking low\n",
+ "41.0: BMI mid\n",
+ "41.0: Alcohol low\n",
+ "42.0: No event\n",
+ "=====================\n",
+ "Generated trajectory:\n",
+ "44.9: E03 Other hypothyroidism\n",
+ "45.2: M19 Other arthrosis\n",
+ "45.7: H60 Otitis externa\n",
+ "47.1: H10 Conjunctivitis\n",
+ "47.6: H66 Suppurative and unspecified otitis media\n",
+ "48.4: M25 Other joint disorders, not elsewhere classified\n",
+ "49.8: M79 Other soft tissue disorders, not elsewhere classified\n",
+ "50.9: M65 Synovitis and tenosynovitis\n",
+ "51.2: D50 Iron deficiency anaemia\n",
+ "51.9: K57 Diverticular disease of intestine\n",
+ "52.6: K63 Other diseases of intestine\n",
+ "52.8: K66 Other disorders of peritoneum\n",
+ "52.9: A09 Diarrhoea and gastro-enteritis of presumed infectious origin\n",
+ "53.8: K20 Oesophagitis\n",
+ "53.8: K29 Gastritis and duodenitis\n",
+ "53.9: L30 Other dermatitis\n",
+ "54.5: M54 Dorsalgia\n",
+ "54.8: M15 Polyarthrosis\n",
+ "56.1: N30 Cystitis\n",
+ "56.3: F07 Personality and behavioural disorders due to brain disease, damage and dysfunction\n",
+ "56.4: K80 Cholelithiasis\n",
+ "56.4: M13 Other arthritis\n",
+ "56.6: B37 Candidiasis\n",
+ "58.0: E11 Non-insulin-dependent diabetes mellitus\n",
+ "58.0: K59 Other functional intestinal disorders\n",
+ "60.5: L03 Cellulitis\n",
+ "60.8: K52 Other non-infective gastro-enteritis and colitis\n",
+ "61.4: K31 Other diseases of stomach and duodenum\n",
+ "61.8: K74 Fibrosis and cirrhosis of liver\n",
+ "61.8: Death\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Extend the example trajectory with the model and print the prompt and the continuation.\n",
+ "max_new_tokens = 100\n",
+ "\n",
+ "# Event name -> token id, built column-wise from the label table instead of row by row.\n",
+ "name_to_token_id = dict(zip(delphi_labels['name'], delphi_labels['index']))\n",
+ "\n",
+ "# Encode the prompt as two tensors of shape (1, n_events): token ids and the matching ages.\n",
+ "events = torch.tensor(\n",
+ "    [name_to_token_id[name] for name, _ in example_health_trajectory], device=device\n",
+ ").unsqueeze(0)\n",
+ "ages = torch.tensor(\n",
+ "    [age for _, age in example_health_trajectory], device=device\n",
+ ").unsqueeze(0)\n",
+ "\n",
+ "# Only the forward passes need gradients disabled.\n",
+ "# Token 1269 is passed as the termination token (row 1269 of the label table is 'Death').\n",
+ "with torch.no_grad():\n",
+ "    y, b, _ = model.generate(events, ages, max_new_tokens, termination_tokens=[1269])\n",
+ "\n",
+ "# The prompt ages were scaled by 365.25 days per year, so convert back with the same\n",
+ "# factor (the previous 365. divisor slightly overstated every printed age).\n",
+ "token_ids = y.cpu().numpy().flatten()\n",
+ "ages_in_years = b.cpu().numpy().flatten() / 365.25\n",
+ "\n",
+ "print('Input trajectory:')\n",
+ "for i, (event_id, age_years) in enumerate(zip(token_ids, ages_in_years)):\n",
+ "    # The first len(example_health_trajectory) tokens are the prompt; the rest were generated.\n",
+ "    if i == len(example_health_trajectory):\n",
+ "        print('=====================')\n",
+ "        print('Generated trajectory:')\n",
+ "    event_name = delphi_labels.loc[event_id, 'name']\n",
+ "    print(f\"{age_years:2.1f}: {event_name}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0b937d75",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/utils.py b/utils.py
index a3954d9..bab1538 100644
--- a/utils.py
+++ b/utils.py
@@ -122,6 +122,8 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
with open(config_path, 'r') as f:
config_dict = json.load(f)
+ print(f"Model config: {config_dict}")
+
# Create a config object from the dictionary
class AttrDict(dict):
def __init__(self, *args, **kwargs):