feat: print model config and add evaluation notebook

This commit is contained in:
2025-10-20 10:14:50 +08:00
parent 6b782b86e1
commit 1c9e2a2fb3
2 changed files with 343 additions and 0 deletions

341
evaluate_models.ipynb Normal file
View File

@@ -0,0 +1,341 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "08bd379a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"from models import TimeAwareGPT2\n",
"from utils import load_model\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"import textwrap\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rcParams['figure.facecolor'] = 'white'\n",
"plt.rcParams.update({'axes.grid': True,\n",
" 'grid.linestyle': ':',\n",
" 'axes.spines.bottom': False,\n",
" 'axes.spines.left': False,\n",
" 'axes.spines.right': False,\n",
" 'axes.spines.top': False})\n",
"plt.rcParams['figure.dpi'] = 72\n",
"plt.rcParams['pdf.fonttype'] = 42\n",
"\n",
"#Green\n",
"light_male = '#BAEBE3'\n",
"normal_male = '#0FB8A1'\n",
"dark_male = '#00574A'\n",
"\n",
"\n",
"#Purple\n",
"light_female = '#DEC7FF'\n",
"normal_female = '#8520F1'\n",
"dark_female = '#7A00BF'\n",
"\n",
" \n",
"delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1d8375ab",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>ICD-10 Chapter (short)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>A41 Other septicaemia</td>\n",
" <td>I. Infectious Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>B01 Varicella [chickenpox]</td>\n",
" <td>I. Infectious Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1168</th>\n",
" <td>C25 Malignant neoplasm of pancreas</td>\n",
" <td>II. Neoplasms</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1188</th>\n",
" <td>C50 Malignant neoplasm of breast</td>\n",
" <td>II. Neoplasms</td>\n",
" </tr>\n",
" <tr>\n",
" <th>374</th>\n",
" <td>G30 Alzheimer's disease</td>\n",
" <td>VI. Nervous System Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>214</th>\n",
" <td>E10 Insulin-dependent diabetes mellitus</td>\n",
" <td>IV. Metabolic Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305</th>\n",
" <td>F32 Depressive episode</td>\n",
" <td>V. Mental Disorders</td>\n",
" </tr>\n",
" <tr>\n",
" <th>505</th>\n",
" <td>I21 Acute myocardial infarction</td>\n",
" <td>IX. Circulatory Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>603</th>\n",
" <td>J45 Asthma</td>\n",
" <td>X. Respiratory Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1269</th>\n",
" <td>Death</td>\n",
" <td>Death</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name ICD-10 Chapter (short)\n",
"46 A41 Other septicaemia I. Infectious Diseases\n",
"95 B01 Varicella [chickenpox] I. Infectious Diseases\n",
"1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n",
"1188 C50 Malignant neoplasm of breast II. Neoplasms\n",
"374 G30 Alzheimer's disease VI. Nervous System Diseases\n",
"214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n",
"305 F32 Depressive episode V. Mental Disorders\n",
"505 I21 Acute myocardial infarction IX. Circulatory Diseases\n",
"603 J45 Asthma X. Respiratory Diseases\n",
"1269 Death Death"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death. \n",
"# For illustrative purposes, some of the plots will focus on a subset of 10 selected diseases - the same subset in used in the Delphi paper. \n",
"\n",
"diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n",
"delphi_labels.iloc[diseases_of_interest][['name', 'ICD-10 Chapter (short)']]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f8d86352",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cpu\n",
"Model config: {'n_layer': 12, 'n_embd': 120, 'n_head': 12, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n",
"Model loaded from best_model_n_embd_120_n_layer_12_n_head_12.pt with 2.40M parameters.\n"
]
}
],
"source": [
"seed = 1337\n",
"\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f'Using device: {device}')\n",
"model = load_model('config_n_embd_120_n_layer_12_n_head_12.json',\n",
" 'best_model_n_embd_120_n_layer_12_n_head_12.pt', \n",
" 1270)\n",
"model.eval()\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a8658fb6",
"metadata": {},
"outputs": [],
"source": [
"# Let's try to use the loaded model to extrapolate a partial health trajectory.\n",
"\n",
"example_health_trajectory = [\n",
" ('Male', 0),\n",
" ('B01 Varicella [chickenpox]',2),\n",
" ('L20 Atopic dermatitis',3),\n",
" ('No event', 5),\n",
" ('No event', 10),\n",
" ('No event', 15),\n",
" ('No event', 20),\n",
" ('G43 Migraine', 20),\n",
" ('E73 Lactose intolerance',21),\n",
" ('B27 Infectious mononucleosis',22),\n",
" ('No event', 25),\n",
" ('J11 Influenza, virus not identified',28),\n",
" ('No event', 30),\n",
" ('No event', 35),\n",
" ('No event', 40),\n",
" ('Smoking low', 41),\n",
" ('BMI mid', 41),\n",
" ('Alcohol low', 41),\n",
" ('No event', 42),\n",
"]\n",
"example_health_trajectory = [(a, b * 365.25) for a,b in example_health_trajectory] "
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d7f2e4a1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input trajectory:\n",
"0.0: Male\n",
"2.0: B01 Varicella [chickenpox]\n",
"3.0: L20 Atopic dermatitis\n",
"5.0: No event\n",
"10.0: No event\n",
"15.0: No event\n",
"20.0: No event\n",
"20.0: G43 Migraine\n",
"21.0: E73 Lactose intolerance\n",
"22.0: B27 Infectious mononucleosis\n",
"25.0: No event\n",
"28.0: J11 Influenza, virus not identified\n",
"30.0: No event\n",
"35.0: No event\n",
"40.0: No event\n",
"41.0: Smoking low\n",
"41.0: BMI mid\n",
"41.0: Alcohol low\n",
"42.0: No event\n",
"=====================\n",
"Generated trajectory:\n",
"44.9: E03 Other hypothyroidism\n",
"45.2: M19 Other arthrosis\n",
"45.7: H60 Otitis externa\n",
"47.1: H10 Conjunctivitis\n",
"47.6: H66 Suppurative and unspecified otitis media\n",
"48.4: M25 Other joint disorders, not elsewhere classified\n",
"49.8: M79 Other soft tissue disorders, not elsewhere classified\n",
"50.9: M65 Synovitis and tenosynovitis\n",
"51.2: D50 Iron deficiency anaemia\n",
"51.9: K57 Diverticular disease of intestine\n",
"52.6: K63 Other diseases of intestine\n",
"52.8: K66 Other disorders of peritoneum\n",
"52.9: A09 Diarrhoea and gastro-enteritis of presumed infectious origin\n",
"53.8: K20 Oesophagitis\n",
"53.8: K29 Gastritis and duodenitis\n",
"53.9: L30 Other dermatitis\n",
"54.5: M54 Dorsalgia\n",
"54.8: M15 Polyarthrosis\n",
"56.1: N30 Cystitis\n",
"56.3: F07 Personality and behavioural disorders due to brain disease, damage and dysfunction\n",
"56.4: K80 Cholelithiasis\n",
"56.4: M13 Other arthritis\n",
"56.6: B37 Candidiasis\n",
"58.0: E11 Non-insulin-dependent diabetes mellitus\n",
"58.0: K59 Other functional intestinal disorders\n",
"60.5: L03 Cellulitis\n",
"60.8: K52 Other non-infective gastro-enteritis and colitis\n",
"61.4: K31 Other diseases of stomach and duodenum\n",
"61.8: K74 Fibrosis and cirrhosis of liver\n",
"61.8: Death\n"
]
}
],
"source": [
"max_new_tokens = 100\n",
"\n",
"name_to_token_id = {row[1]['name']: row[1]['index'] for row in delphi_labels.iterrows()}\n",
"\n",
"events = [name_to_token_id[event[0]] for event in example_health_trajectory]\n",
"events = torch.tensor(events, device=device).unsqueeze(0)\n",
"ages = [event[1] for event in example_health_trajectory]\n",
"ages = torch.tensor(ages, device=device).unsqueeze(0)\n",
"\n",
"res = []\n",
"with torch.no_grad():\n",
" y,b,_ = model.generate(events, ages, max_new_tokens, termination_tokens=[1269])\n",
" # Convert model outputs to readable format\n",
" events_data = zip(y.cpu().numpy().flatten(), b.cpu().numpy().flatten()/365.)\n",
" \n",
" print('Input trajectory:')\n",
" for i, (event_id, age_years) in enumerate(events_data):\n",
" if i == len(example_health_trajectory):\n",
" print('=====================')\n",
" print('Generated trajectory:')\n",
" event_name = delphi_labels.loc[event_id, 'name']\n",
" print(f\"{age_years:2.1f}: {event_name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b937d75",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -122,6 +122,8 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
config_dict = json.load(f) config_dict = json.load(f)
print(f"Model config: {config_dict}")
# Create a config object from the dictionary # Create a config object from the dictionary
class AttrDict(dict): class AttrDict(dict):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):