feat: print model config and add evaluation notebook
This commit is contained in:
341
evaluate_models.ipynb
Normal file
341
evaluate_models.ipynb
Normal file
@@ -0,0 +1,341 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "08bd379a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"from models import TimeAwareGPT2\n",
|
||||
"from utils import load_model\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import textwrap\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"# Global matplotlib styling, applied in a single update call.\n",
"plt.rcParams.update({\n",
"    'figure.facecolor': 'white',\n",
"    'axes.grid': True,\n",
"    'grid.linestyle': ':',\n",
"    'axes.spines.bottom': False,\n",
"    'axes.spines.left': False,\n",
"    'axes.spines.right': False,\n",
"    'axes.spines.top': False,\n",
"    'figure.dpi': 72,\n",
"    'pdf.fonttype': 42,\n",
"})\n",
"\n",
"# Green shades (the *_male palette)\n",
"light_male, normal_male, dark_male = '#BAEBE3', '#0FB8A1', '#00574A'\n",
"\n",
"# Purple shades (the *_female palette)\n",
"light_female, normal_female, dark_female = '#DEC7FF', '#8520F1', '#7A00BF'\n",
"\n",
"# Label table; later cells read its 'name', 'index' and 'ICD-10 Chapter (short)' columns.\n",
"delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "1d8375ab",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>name</th>\n",
|
||||
" <th>ICD-10 Chapter (short)</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>46</th>\n",
|
||||
" <td>A41 Other septicaemia</td>\n",
|
||||
" <td>I. Infectious Diseases</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>95</th>\n",
|
||||
" <td>B01 Varicella [chickenpox]</td>\n",
|
||||
" <td>I. Infectious Diseases</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1168</th>\n",
|
||||
" <td>C25 Malignant neoplasm of pancreas</td>\n",
|
||||
" <td>II. Neoplasms</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1188</th>\n",
|
||||
" <td>C50 Malignant neoplasm of breast</td>\n",
|
||||
" <td>II. Neoplasms</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>374</th>\n",
|
||||
" <td>G30 Alzheimer's disease</td>\n",
|
||||
" <td>VI. Nervous System Diseases</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>214</th>\n",
|
||||
" <td>E10 Insulin-dependent diabetes mellitus</td>\n",
|
||||
" <td>IV. Metabolic Diseases</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>305</th>\n",
|
||||
" <td>F32 Depressive episode</td>\n",
|
||||
" <td>V. Mental Disorders</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>505</th>\n",
|
||||
" <td>I21 Acute myocardial infarction</td>\n",
|
||||
" <td>IX. Circulatory Diseases</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>603</th>\n",
|
||||
" <td>J45 Asthma</td>\n",
|
||||
" <td>X. Respiratory Diseases</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1269</th>\n",
|
||||
" <td>Death</td>\n",
|
||||
" <td>Death</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" name ICD-10 Chapter (short)\n",
|
||||
"46 A41 Other septicaemia I. Infectious Diseases\n",
|
||||
"95 B01 Varicella [chickenpox] I. Infectious Diseases\n",
|
||||
"1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n",
|
||||
"1188 C50 Malignant neoplasm of breast II. Neoplasms\n",
|
||||
"374 G30 Alzheimer's disease VI. Nervous System Diseases\n",
|
||||
"214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n",
|
||||
"305 F32 Depressive episode V. Mental Disorders\n",
|
||||
"505 I21 Acute myocardial infarction IX. Circulatory Diseases\n",
|
||||
"603 J45 Asthma X. Respiratory Diseases\n",
|
||||
"1269 Death Death"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death.\n",
"# For illustrative purposes, some of the plots will focus on a subset of 10 selected diseases -\n",
"# the same subset used in the Delphi paper.\n",
"\n",
"diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n",
"\n",
"# Show the name and ICD-10 chapter of each selected row (rows are picked by position).\n",
"columns_to_show = ['name', 'ICD-10 Chapter (short)']\n",
"delphi_labels.iloc[diseases_of_interest][columns_to_show]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "f8d86352",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using device: cpu\n",
|
||||
"Model config: {'n_layer': 12, 'n_embd': 120, 'n_head': 12, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n",
|
||||
"Model loaded from best_model_n_embd_120_n_layer_12_n_head_12.pt with 2.40M parameters.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"seed = 1337\n",
"\n",
"# Seed the CPU and CUDA random number generators.\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"\n",
"# Use the GPU when one is available, otherwise fall back to the CPU.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"print(f'Using device: {device}')\n",
"\n",
"# Build the model from its JSON config and checkpoint (third argument is the vocabulary size),\n",
"# switch it to evaluation mode, then move it to the selected device.\n",
"model = load_model(\n",
"    'config_n_embd_120_n_layer_12_n_head_12.json',\n",
"    'best_model_n_embd_120_n_layer_12_n_head_12.pt',\n",
"    1270,\n",
")\n",
"model.eval()\n",
"model = model.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "a8658fb6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Let's try to use the loaded model to extrapolate a partial health trajectory.\n",
"\n",
"# Each entry is an (event name, age in years) pair.\n",
"example_health_trajectory = [\n",
"    ('Male', 0),\n",
"    ('B01 Varicella [chickenpox]', 2),\n",
"    ('L20 Atopic dermatitis', 3),\n",
"    ('No event', 5),\n",
"    ('No event', 10),\n",
"    ('No event', 15),\n",
"    ('No event', 20),\n",
"    ('G43 Migraine', 20),\n",
"    ('E73 Lactose intolerance', 21),\n",
"    ('B27 Infectious mononucleosis', 22),\n",
"    ('No event', 25),\n",
"    ('J11 Influenza, virus not identified', 28),\n",
"    ('No event', 30),\n",
"    ('No event', 35),\n",
"    ('No event', 40),\n",
"    ('Smoking low', 41),\n",
"    ('BMI mid', 41),\n",
"    ('Alcohol low', 41),\n",
"    ('No event', 42),\n",
"]\n",
"\n",
"# Convert the ages from years to days (365.25 days per year).\n",
"example_health_trajectory = [\n",
"    (event_name, age_years * 365.25)\n",
"    for event_name, age_years in example_health_trajectory\n",
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "d7f2e4a1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Input trajectory:\n",
|
||||
"0.0: Male\n",
|
||||
"2.0: B01 Varicella [chickenpox]\n",
|
||||
"3.0: L20 Atopic dermatitis\n",
|
||||
"5.0: No event\n",
|
||||
"10.0: No event\n",
|
||||
"15.0: No event\n",
|
||||
"20.0: No event\n",
|
||||
"20.0: G43 Migraine\n",
|
||||
"21.0: E73 Lactose intolerance\n",
|
||||
"22.0: B27 Infectious mononucleosis\n",
|
||||
"25.0: No event\n",
|
||||
"28.0: J11 Influenza, virus not identified\n",
|
||||
"30.0: No event\n",
|
||||
"35.0: No event\n",
|
||||
"40.0: No event\n",
|
||||
"41.0: Smoking low\n",
|
||||
"41.0: BMI mid\n",
|
||||
"41.0: Alcohol low\n",
|
||||
"42.0: No event\n",
|
||||
"=====================\n",
|
||||
"Generated trajectory:\n",
|
||||
"44.9: E03 Other hypothyroidism\n",
|
||||
"45.2: M19 Other arthrosis\n",
|
||||
"45.7: H60 Otitis externa\n",
|
||||
"47.1: H10 Conjunctivitis\n",
|
||||
"47.6: H66 Suppurative and unspecified otitis media\n",
|
||||
"48.4: M25 Other joint disorders, not elsewhere classified\n",
|
||||
"49.8: M79 Other soft tissue disorders, not elsewhere classified\n",
|
||||
"50.9: M65 Synovitis and tenosynovitis\n",
|
||||
"51.2: D50 Iron deficiency anaemia\n",
|
||||
"51.9: K57 Diverticular disease of intestine\n",
|
||||
"52.6: K63 Other diseases of intestine\n",
|
||||
"52.8: K66 Other disorders of peritoneum\n",
|
||||
"52.9: A09 Diarrhoea and gastro-enteritis of presumed infectious origin\n",
|
||||
"53.8: K20 Oesophagitis\n",
|
||||
"53.8: K29 Gastritis and duodenitis\n",
|
||||
"53.9: L30 Other dermatitis\n",
|
||||
"54.5: M54 Dorsalgia\n",
|
||||
"54.8: M15 Polyarthrosis\n",
|
||||
"56.1: N30 Cystitis\n",
|
||||
"56.3: F07 Personality and behavioural disorders due to brain disease, damage and dysfunction\n",
|
||||
"56.4: K80 Cholelithiasis\n",
|
||||
"56.4: M13 Other arthritis\n",
|
||||
"56.6: B37 Candidiasis\n",
|
||||
"58.0: E11 Non-insulin-dependent diabetes mellitus\n",
|
||||
"58.0: K59 Other functional intestinal disorders\n",
|
||||
"60.5: L03 Cellulitis\n",
|
||||
"60.8: K52 Other non-infective gastro-enteritis and colitis\n",
|
||||
"61.4: K31 Other diseases of stomach and duodenum\n",
|
||||
"61.8: K74 Fibrosis and cirrhosis of liver\n",
|
||||
"61.8: Death\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"max_new_tokens = 100\n",
"\n",
"# Map each label name to its token id (the 'index' column of the labels table).\n",
"name_to_token_id = dict(zip(delphi_labels['name'], delphi_labels['index']))\n",
"\n",
"# Encode the example trajectory as two tensors with a leading batch dimension of 1:\n",
"# token ids and the matching ages in days.\n",
"events = torch.tensor(\n",
"    [name_to_token_id[event_name] for event_name, _age in example_health_trajectory],\n",
"    device=device,\n",
").unsqueeze(0)\n",
"ages = torch.tensor(\n",
"    [age_days for _name, age_days in example_health_trajectory],\n",
"    device=device,\n",
").unsqueeze(0)\n",
"\n",
"# Generate without tracking gradients; token 1269 is passed as the termination token.\n",
"with torch.no_grad():\n",
"    y, b, _ = model.generate(events, ages, max_new_tokens, termination_tokens=[1269])\n",
"\n",
"# Convert model outputs to a readable format.\n",
"# The input ages were converted to days with 365.25 days/year, so the same factor is\n",
"# used to convert back (dividing by 365 would slightly inflate every printed age).\n",
"event_ids = y.cpu().numpy().flatten()\n",
"ages_in_years = b.cpu().numpy().flatten() / 365.25\n",
"\n",
"print('Input trajectory:')\n",
"for i, (event_id, age_years) in enumerate(zip(event_ids, ages_in_years)):\n",
"    # The first len(example_health_trajectory) entries echo the input; the rest are generated.\n",
"    if i == len(example_health_trajectory):\n",
"        print('=====================')\n",
"        print('Generated trajectory:')\n",
"    # NOTE(review): looks up the name by row label; assumes row labels equal token ids - confirm.\n",
"    event_name = delphi_labels.loc[event_id, 'name']\n",
"    print(f'{age_years:2.1f}: {event_name}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0b937d75",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
2
utils.py
2
utils.py
@@ -122,6 +122,8 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
|
||||
with open(config_path, 'r') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
print(f"Model config: {config_dict}")
|
||||
|
||||
# Create a config object from the dictionary
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
Reference in New Issue
Block a user