From 1c9e2a2fb3abdc6934ee76b561e8d92ef0e46d7c Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Mon, 20 Oct 2025 10:14:50 +0800 Subject: [PATCH] feat: print model config and add evaluation notebook --- evaluate_models.ipynb | 341 ++++++++++++++++++++++++++++++++++++++++++ utils.py | 2 + 2 files changed, 343 insertions(+) create mode 100644 evaluate_models.ipynb diff --git a/evaluate_models.ipynb b/evaluate_models.ipynb new file mode 100644 index 0000000..2a3845a --- /dev/null +++ b/evaluate_models.ipynb @@ -0,0 +1,341 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "08bd379a", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from models import TimeAwareGPT2\n", + "from utils import load_model\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "import numpy as np\n", + "import textwrap\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.rcParams['figure.facecolor'] = 'white'\n", + "plt.rcParams.update({'axes.grid': True,\n", + " 'grid.linestyle': ':',\n", + " 'axes.spines.bottom': False,\n", + " 'axes.spines.left': False,\n", + " 'axes.spines.right': False,\n", + " 'axes.spines.top': False})\n", + "plt.rcParams['figure.dpi'] = 72\n", + "plt.rcParams['pdf.fonttype'] = 42\n", + "\n", + "#Green\n", + "light_male = '#BAEBE3'\n", + "normal_male = '#0FB8A1'\n", + "dark_male = '#00574A'\n", + "\n", + "\n", + "#Purple\n", + "light_female = '#DEC7FF'\n", + "normal_female = '#8520F1'\n", + "dark_female = '#7A00BF'\n", + "\n", + " \n", + "delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1d8375ab", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nameICD-10 Chapter (short)
46A41 Other septicaemiaI. Infectious Diseases
95B01 Varicella [chickenpox]I. Infectious Diseases
1168C25 Malignant neoplasm of pancreasII. Neoplasms
1188C50 Malignant neoplasm of breastII. Neoplasms
374G30 Alzheimer's diseaseVI. Nervous System Diseases
214E10 Insulin-dependent diabetes mellitusIV. Metabolic Diseases
305F32 Depressive episodeV. Mental Disorders
505I21 Acute myocardial infarctionIX. Circulatory Diseases
603J45 AsthmaX. Respiratory Diseases
1269DeathDeath
\n", + "
" + ], + "text/plain": [ + " name ICD-10 Chapter (short)\n", + "46 A41 Other septicaemia I. Infectious Diseases\n", + "95 B01 Varicella [chickenpox] I. Infectious Diseases\n", + "1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n", + "1188 C50 Malignant neoplasm of breast II. Neoplasms\n", + "374 G30 Alzheimer's disease VI. Nervous System Diseases\n", + "214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n", + "305 F32 Depressive episode V. Mental Disorders\n", + "505 I21 Acute myocardial infarction IX. Circulatory Diseases\n", + "603 J45 Asthma X. Respiratory Diseases\n", + "1269 Death Death" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death. \n", + "# For illustrative purposes, some of the plots will focus on a subset of 10 selected diseases - the same subset in used in the Delphi paper. \n", + "\n", + "diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n", + "delphi_labels.iloc[diseases_of_interest][['name', 'ICD-10 Chapter (short)']]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f8d86352", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n", + "Model config: {'n_layer': 12, 'n_embd': 120, 'n_head': 12, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n", + "Model loaded from best_model_n_embd_120_n_layer_12_n_head_12.pt with 2.40M parameters.\n" + ] + } + ], + "source": [ + "seed = 1337\n", + "\n", + "torch.manual_seed(seed)\n", + "torch.cuda.manual_seed(seed)\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "print(f'Using device: {device}')\n", + "model = load_model('config_n_embd_120_n_layer_12_n_head_12.json',\n", + " 'best_model_n_embd_120_n_layer_12_n_head_12.pt', \n", + " 1270)\n", + "model.eval()\n", + "model = model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a8658fb6", + "metadata": {}, + "outputs": [], + "source": [ + "# Let's try to use the loaded model to extrapolate a partial health trajectory.\n", + "\n", + "example_health_trajectory = [\n", + " ('Male', 0),\n", + " ('B01 Varicella [chickenpox]',2),\n", + " ('L20 Atopic dermatitis',3),\n", + " ('No event', 5),\n", + " ('No event', 10),\n", + " ('No event', 15),\n", + " ('No event', 20),\n", + " ('G43 Migraine', 20),\n", + " ('E73 Lactose intolerance',21),\n", + " ('B27 Infectious mononucleosis',22),\n", + " ('No event', 25),\n", + " ('J11 Influenza, virus not identified',28),\n", + " ('No event', 30),\n", + " ('No event', 35),\n", + " ('No event', 40),\n", + " ('Smoking low', 41),\n", + " ('BMI mid', 41),\n", + " ('Alcohol low', 41),\n", + " ('No event', 42),\n", + "]\n", + "example_health_trajectory = [(a, b * 365.25) for a,b in example_health_trajectory] " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d7f2e4a1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input trajectory:\n", + "0.0: Male\n", + "2.0: B01 Varicella [chickenpox]\n", + "3.0: L20 Atopic dermatitis\n", + "5.0: No event\n", + "10.0: No event\n", + "15.0: No event\n", + "20.0: No event\n", + "20.0: G43 Migraine\n", + "21.0: E73 Lactose intolerance\n", + "22.0: B27 Infectious mononucleosis\n", + "25.0: No event\n", + "28.0: J11 Influenza, virus not identified\n", + "30.0: No event\n", + "35.0: No event\n", + "40.0: No event\n", + "41.0: Smoking low\n", + "41.0: BMI mid\n", + "41.0: Alcohol low\n", + "42.0: No event\n", + "=====================\n", + "Generated trajectory:\n", + "44.9: E03 Other hypothyroidism\n", + "45.2: M19 Other arthrosis\n", + "45.7: H60 Otitis externa\n", + "47.1: H10 Conjunctivitis\n", + "47.6: H66 Suppurative and unspecified otitis media\n", + "48.4: M25 Other joint disorders, not elsewhere classified\n", + "49.8: M79 Other soft tissue disorders, not elsewhere classified\n", + "50.9: M65 Synovitis and tenosynovitis\n", + "51.2: D50 Iron deficiency anaemia\n", + "51.9: K57 Diverticular disease of intestine\n", + "52.6: K63 Other diseases of intestine\n", + "52.8: K66 Other disorders of peritoneum\n", + "52.9: A09 Diarrhoea and gastro-enteritis of presumed infectious origin\n", + "53.8: K20 Oesophagitis\n", + "53.8: K29 Gastritis and duodenitis\n", + "53.9: L30 Other dermatitis\n", + "54.5: M54 Dorsalgia\n", + "54.8: M15 Polyarthrosis\n", + "56.1: N30 Cystitis\n", + "56.3: F07 Personality and behavioural disorders due to brain disease, damage and dysfunction\n", + "56.4: K80 Cholelithiasis\n", + "56.4: M13 Other arthritis\n", + "56.6: B37 Candidiasis\n", + "58.0: E11 Non-insulin-dependent diabetes mellitus\n", + "58.0: K59 Other functional intestinal disorders\n", + "60.5: L03 Cellulitis\n", + "60.8: K52 Other non-infective gastro-enteritis and colitis\n", + "61.4: K31 Other diseases of stomach and duodenum\n", + "61.8: K74 Fibrosis and cirrhosis of liver\n", + "61.8: Death\n" + ] + } + ], + "source": [ + "max_new_tokens = 100\n", + "\n", + "name_to_token_id = {row[1]['name']: row[1]['index'] for row in delphi_labels.iterrows()}\n", + "\n", + "events = [name_to_token_id[event[0]] for event in example_health_trajectory]\n", + "events = torch.tensor(events, device=device).unsqueeze(0)\n", + "ages = [event[1] for event in example_health_trajectory]\n", + "ages = torch.tensor(ages, device=device).unsqueeze(0)\n", + "\n", + "res = []\n", + "with torch.no_grad():\n", + " y,b,_ = model.generate(events, ages, max_new_tokens, termination_tokens=[1269])\n", + " # Convert model outputs to readable format\n", + " events_data = zip(y.cpu().numpy().flatten(), b.cpu().numpy().flatten()/365.)\n", + " \n", + " print('Input trajectory:')\n", + " for i, (event_id, age_years) in enumerate(events_data):\n", + " if i == len(example_health_trajectory):\n", + " print('=====================')\n", + " print('Generated trajectory:')\n", + " event_name = delphi_labels.loc[event_id, 'name']\n", + " print(f\"{age_years:2.1f}: {event_name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b937d75", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/utils.py b/utils.py index a3954d9..bab1538 100644 --- a/utils.py +++ b/utils.py @@ -122,6 +122,8 @@ def load_model(config_path, model_path, vocab_size, device='cpu'): with open(config_path, 'r') as f: config_dict = json.load(f) + print(f"Model config: {config_dict}") + # Create a config object from the dictionary class AttrDict(dict): def __init__(self, *args, **kwargs):