Files
DeepHealth/evaluate_models.ipynb

553 lines
176 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "08bd379a",
"metadata": {},
"outputs": [],
"source": [
"# Standard library\n",
"import textwrap\n",
"\n",
"# Third-party\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from tqdm import tqdm\n",
"\n",
"# Project-local\n",
"from models import TimeAwareGPT2\n",
"from utils import load_model\n",
"\n",
"# Global matplotlib style: white figures, dotted grid, no axis spines.\n",
"# pdf.fonttype 42 embeds fonts as TrueType (Type 42) in exported PDFs.\n",
"plt.rcParams.update({\n",
"    'figure.facecolor': 'white',\n",
"    'figure.dpi': 72,\n",
"    'pdf.fonttype': 42,\n",
"    'axes.grid': True,\n",
"    'grid.linestyle': ':',\n",
"    'axes.spines.bottom': False,\n",
"    'axes.spines.left': False,\n",
"    'axes.spines.right': False,\n",
"    'axes.spines.top': False,\n",
"})\n",
"\n",
"# Green palette (male)\n",
"light_male = '#BAEBE3'\n",
"normal_male = '#0FB8A1'\n",
"dark_male = '#00574A'\n",
"\n",
"# Purple palette (female)\n",
"light_female = '#DEC7FF'\n",
"normal_female = '#8520F1'\n",
"dark_female = '#7A00BF'\n",
"\n",
"# Token labels; columns used below include 'name', 'index' and 'ICD-10 Chapter (short)'.\n",
"delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1d8375ab",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>ICD-10 Chapter (short)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>46</th>\n",
" <td>A41 Other septicaemia</td>\n",
" <td>I. Infectious Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>B01 Varicella [chickenpox]</td>\n",
" <td>I. Infectious Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1168</th>\n",
" <td>C25 Malignant neoplasm of pancreas</td>\n",
" <td>II. Neoplasms</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1188</th>\n",
" <td>C50 Malignant neoplasm of breast</td>\n",
" <td>II. Neoplasms</td>\n",
" </tr>\n",
" <tr>\n",
" <th>374</th>\n",
" <td>G30 Alzheimer's disease</td>\n",
" <td>VI. Nervous System Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>214</th>\n",
" <td>E10 Insulin-dependent diabetes mellitus</td>\n",
" <td>IV. Metabolic Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>305</th>\n",
" <td>F32 Depressive episode</td>\n",
" <td>V. Mental Disorders</td>\n",
" </tr>\n",
" <tr>\n",
" <th>505</th>\n",
" <td>I21 Acute myocardial infarction</td>\n",
" <td>IX. Circulatory Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>603</th>\n",
" <td>J45 Asthma</td>\n",
" <td>X. Respiratory Diseases</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1269</th>\n",
" <td>Death</td>\n",
" <td>Death</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name ICD-10 Chapter (short)\n",
"46 A41 Other septicaemia I. Infectious Diseases\n",
"95 B01 Varicella [chickenpox] I. Infectious Diseases\n",
"1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n",
"1188 C50 Malignant neoplasm of breast II. Neoplasms\n",
"374 G30 Alzheimer's disease VI. Nervous System Diseases\n",
"214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n",
"305 F32 Depressive episode V. Mental Disorders\n",
"505 I21 Acute myocardial infarction IX. Circulatory Diseases\n",
"603 J45 Asthma X. Respiratory Diseases\n",
"1269 Death Death"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Delphi can predict the risk of 1,256 ICD-10 diseases plus death.\n",
"# For illustration, some plots focus on a subset of 10 diseases - the same subset used in the Delphi paper.\n",
"# The ids below are positional row numbers in delphi_labels.\n",
"diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n",
"label_columns = ['name', 'ICD-10 Chapter (short)']\n",
"delphi_labels[label_columns].iloc[diseases_of_interest]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f8d86352",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cpu\n",
"Model config: {'n_layer': 12, 'n_embd': 120, 'n_head': 12, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n",
"Model loaded from best_model_n_embd_120_n_layer_12_n_head_12.pt with 2.40M parameters.\n"
]
}
],
"source": [
"# Reproducibility and device selection.\n",
"seed = 1337\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(f'Using device: {device}')\n",
"\n",
"# Load the trained checkpoint matching this architecture string.\n",
"model_name = 'n_embd_120_n_layer_12_n_head_12'\n",
"# NOTE(review): presumably the vocabulary size (token ids up to 1269 are used below) - confirm against utils.load_model.\n",
"vocab_size = 1270\n",
"model = load_model(f'config_{model_name}.json', f'best_model_{model_name}.pt', vocab_size)\n",
"model = model.to(device)\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a8658fb6",
"metadata": {},
"outputs": [],
"source": [
"# A partial health trajectory for the loaded model to extrapolate.\n",
"# Entries are (token name, age in years); the ages are converted to days below.\n",
"trajectory_in_years = [\n",
"    ('Male', 0),\n",
"    ('B01 Varicella [chickenpox]', 2),\n",
"    ('L20 Atopic dermatitis', 3),\n",
"    ('No event', 5),\n",
"    ('No event', 10),\n",
"    ('No event', 15),\n",
"    ('No event', 20),\n",
"    ('G43 Migraine', 20),\n",
"    ('E73 Lactose intolerance', 21),\n",
"    ('B27 Infectious mononucleosis', 22),\n",
"    ('No event', 25),\n",
"    ('J11 Influenza, virus not identified', 28),\n",
"    ('No event', 30),\n",
"    ('No event', 35),\n",
"    ('No event', 40),\n",
"    ('Smoking low', 41),\n",
"    ('BMI mid', 41),\n",
"    ('Alcohol low', 41),\n",
"    ('No event', 42),\n",
"]\n",
"example_health_trajectory = [\n",
"    (event_name, age_years * 365.25) for event_name, age_years in trajectory_in_years\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d7f2e4a1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input trajectory:\n",
"0.0: Male\n",
"2.0: B01 Varicella [chickenpox]\n",
"3.0: L20 Atopic dermatitis\n",
"5.0: No event\n",
"10.0: No event\n",
"15.0: No event\n",
"20.0: No event\n",
"20.0: G43 Migraine\n",
"21.0: E73 Lactose intolerance\n",
"22.0: B27 Infectious mononucleosis\n",
"25.0: No event\n",
"28.0: J11 Influenza, virus not identified\n",
"30.0: No event\n",
"35.0: No event\n",
"40.0: No event\n",
"41.0: Smoking low\n",
"41.0: BMI mid\n",
"41.0: Alcohol low\n",
"42.0: No event\n",
"=====================\n",
"Generated trajectory:\n",
"44.6: H35 Other retinal disorders\n",
"46.0: D50 Iron deficiency anaemia\n",
"47.1: No event\n",
"50.7: E55 Vitamin d deficiency\n",
"50.7: I10 Essential primary hypertension\n",
"50.8: D64 Other anaemias\n",
"50.8: H54 Blindness and low vision\n",
"50.8: K76 Other diseases of liver\n",
"51.1: E83 Disorders of mineral metabolism\n",
"51.2: N95 Menopausal and other perimenopausal disorders\n",
"51.2: M79 Other soft tissue disorders, not elsewhere classified\n",
"51.3: K91 Postprocedural disorders of digestive system, not elsewhere classified\n",
"51.3: B37 Candidiasis\n",
"51.4: E22 Hyperfunction of pituitary gland\n",
"51.6: N85 Other noninflammatory disorders of uterus, except cervix\n",
"51.7: G93 Other disorders of brain\n",
"51.7: K59 Other functional intestinal disorders\n",
"51.9: B96 Other bacterial agents as the cause of diseases classified to other chapters\n",
"51.9: E03 Other hypothyroidism\n",
"52.1: F32 Depressive episode\n",
"52.2: I47 Paroxysmal tachycardia\n",
"52.2: J90 Pleural effusion, not elsewhere classified\n",
"52.4: J96 Respiratory failure, not elsewhere classified\n",
"52.4: J18 Pneumonia, organism unspecified\n",
"52.4: E86 Volume depletion\n",
"52.4: G31 Other degenerative diseases of nervous system, not elsewhere classified\n",
"52.5: B59 Pneumocystosis\n",
"52.6: Death\n"
]
}
],
"source": [
"# Extend the example trajectory with model.generate, using at most max_new_tokens new\n",
"# tokens and the Death token as the termination token.\n",
"max_new_tokens = 100\n",
"death_token_id = 1269  # 'Death' row of delphi_labels (see the disease table above)\n",
"\n",
"# Name <-> token-id lookups, both taken from the 'index' column so they stay consistent.\n",
"name_to_token_id = dict(zip(delphi_labels['name'], delphi_labels['index']))\n",
"token_id_to_name = dict(zip(delphi_labels['index'], delphi_labels['name']))\n",
"\n",
"events = torch.tensor([name_to_token_id[name] for name, _ in example_health_trajectory], device=device).unsqueeze(0)\n",
"ages = torch.tensor([age for _, age in example_health_trajectory], device=device).unsqueeze(0)\n",
"\n",
"with torch.no_grad():\n",
"    y, b, _ = model.generate(events, ages, max_new_tokens, termination_tokens=[death_token_id])\n",
"\n",
"# Ages were converted to days with 365.25 days/year, so convert back with the same factor.\n",
"output_token_ids = y.cpu().numpy().flatten()\n",
"output_ages_years = b.cpu().numpy().flatten() / 365.25\n",
"\n",
"n_input_events = len(example_health_trajectory)\n",
"print('Input trajectory:')\n",
"for i, (event_id, age_years) in enumerate(zip(output_token_ids, output_ages_years)):\n",
"    if i == n_input_events:\n",
"        print('=====================')\n",
"        print('Generated trajectory:')\n",
"    print(f'{age_years:2.1f}: {token_id_to_name[event_id]}')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0b937d75",
"metadata": {},
"outputs": [],
"source": [
"from utils import PatientEventDataset, get_batch\n",
"\n",
"# Each .bin file is read as a flat uint32 memmap and viewed as rows of 3 values.\n",
"train_data_path = 'ukb_real_train.bin'\n",
"val_data_path = 'ukb_real_val.bin'\n",
"train_data_arr, val_data_arr = (\n",
"    np.memmap(path, dtype=np.uint32, mode='r').reshape(-1, 3)\n",
"    for path in (train_data_path, val_data_path)\n",
")\n",
"\n",
"# Wrap each split in a PatientEventDataset (block_length presumably the per-sample sequence length).\n",
"block_length = 128\n",
"train_dataset, val_dataset = (\n",
"    PatientEventDataset(arr, block_length) for arr in (train_data_arr, val_data_arr)\n",
")\n",
"dataset_subset_size = 2048"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6a3bb2dc",
"metadata": {},
"outputs": [],
"source": [
"# Run the model once on validation samples slice(256), without gradients.\n",
"input_events, input_times, target_events, target_times = get_batch(val_dataset, slice(256))\n",
"with torch.no_grad():\n",
"    logits = model(input_events.to(device), input_times.to(device))\n",
"\n",
"# p: model outputs per position; t: observed waiting time = target time - input time.\n",
"p = logits.cpu().detach().numpy().squeeze()\n",
"t = (target_times - input_times).cpu().numpy().squeeze()\n",
"target_events = target_events.cpu().numpy().squeeze()\n",
"\n",
"# Exclude positions whose target token is 0 or 2-12 (presumably padding / non-disease\n",
"# tokens such as sex and lifestyle markers - confirm against delphi_labels).\n",
"ignored_token_ids = [0] + list(range(2, 13))\n",
"mask = ~np.isin(target_events, ignored_token_ids)\n",
"p = p[mask]\n",
"t = t[mask]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c32bd6e",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ4AAAELCAYAAAAofGgWAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAAUttJREFUeJztXQe4FEW2PgTJ0QBiQgUEyYogiCwoKq4R4xpQQESeGBCVFRVRMb7FnHksplUxra4YUAQBQUAJRlBEUBQXBCRzycz7/oIz1C2quqtnunt65tb/ffe7Mz3dXdWnu06ffEqlUqkUOTg4OARA6SA7Ozg4OACOcTg4OASGYxwODg6B4RiHg4NDYDjG4eDgUHiM48cff8z1FBwcHPKNcQBbtmzJ9RQSQwdHi11w9MgdLUrlQxzHxo0bqUKFClTSAToAjhbe9CgqKqL169dT5cqVqVKlStpjsc+yZcvEYttzzz3FH29fuXKl+FyzZk1xPO+7YcMG2r59O5UqVYqqVKlCpUuXpq1bt1LZsmXF9nXr1hEvJ/yXlxaOKV++fHrOYQPzADCXbLBkyRKaNGkSXXHFFenr1yG7UWJCHvC2WODo4E8PLPKff/5Z/IbFesghh+z28GOfBQsWpL//97//TTMgPhYAA9lvv/3Sv8tYvXp14LlGxTTCApjGLbfcQnfeeae4dvw1bdo0f1UVBwdbQNKQ3/r4rttHxZo1a4odK28vCdi4cSPdeuutgmkceOCBvvvnhcTh4GALqCeQNFjiwHfdPiogbWzevHm37dWqVRMqSCHigx/X0huzV1OVcqWpZoUy1LDbHTR2WRWquW6N+N7xkN3plFc2DtzQcuXKFdsG3XTRokWJF//CBN8qLIigwMI44IADaI899qBCAS909dkIauOAveLPP/9M07dq1arCVmCycQD8H4DtYtOmTZRLbNu2TfwvU6aM1f5fL9lAt41bSvX2LCeYxMqN22jVhm3i/9adlzZn8HFG+uUF48DNVR946KK4wXvttVdGCykfwQ8rjHJBgFuMhbF27Vqh8xcK2IuQLTMEU/jjjz/S32vXrk377LPPbvsxQyoqKhK0ZIBx6aQVbAdTwbmg8ixfvrzY72BMbIj1A5igTsXKhHEsW7+V+n2wmCps/JPabfmSel/WM/2bkNTKV6YKNfah+rWq5reqInN3BiSNgw8+uMQwjWwAGoHBYoEUEnTPRRjqDRbf77//Ln6TpQ7ZcCpDxzR4O/7AZHAekyfEBl5MIwi2bEvRfZ8uo6KVS6nU5Cfo5HuGFPsd11+zcnna34Np5DXjABzTsEch0iosxgHGAEkMixNMY/HixcU8K4ceeqjWcBoEOiYB28nee++9mySSCWznNmz6Cpq7tIgqT36K7r9niNYQamKEecc4HEoGbGwTUQHj4Q9SmboIeU4slWQCWbVh4Fy45rgw5qe19OFP6+i85nvSRRc+vZttKIh04xhHwgBd1dbAVUiwib+Ig1HpGATuB/arU6eOuD/btm1LSwnYt0aNGmIfk+SA39gGoZtDHJj35yZ6+osV1HLfCnRJixpUpnR2EmhexHEENQbGia5du1KrVq2oSZMm9H//93/0zDPP0IABA9K/P//883T11VeLzy+99BK1adOGWrZsSX369Ek/TLDq33DDDdSiRQuaOnUqDRkyhFq3bi2CbxDBxw/x9OnTxbH4wxgcnIPz4DuOad68OQ0bNozyCVg8S5cu9Y2/0D0XmT4bzKhgFMV/XsCstsAmgT8EgEF1wX74D8ayTWICmCve3Pvuu69QaXAvVZiYRpgAAzOpo6s3bqM73p1D66eOpAHH7u3LNHTu6ryUOPwejjvfnU1z/htuoE7j/arR7ac38d3v2WefFeHKcNVh4Y4bN47at29PQ4cOFb+/9tprIrDm+++/F58/++wz4QXo27cvvfzyy3TppZeKRXL00UfTgw8+uGPsxo1p8ODB4vMll1xC7733Hp1++unUq1
cvwRRw/oEDB6bnMGLECKpevbpgLHAL4veTTjopLzwoOqOjKf4izBeKLlCMpQ5WWwBZdcF/RJFuVEIAwBiwH+actLiPbdtTNOTdOTTvjaH04H13UfUK/tKsjdE2LxhHkj3Gjz32GL399tvi82+//SYWAd4806ZNowYNGtAPP/wgFvKTTz5JM2fOFMwFAKOpVatWWpQ955xz0uccP348/eMf/xCLasWKFUKa6dChg9CT27VrJ/a56KKLBEMBxowZQ9988w29+eab6XDoefPm5QXjUI2OeGODLjZqSjbPhU2gGKCqjbq4oeWK2pIp2HsTFranUvTwZ8to0ssP03UDB1OH5vWsjrOJSckLxuEn6tlIBlFgwoQJNHbsWKFe4KZ36tRJPFgXXHABvf7669SoUSM666yz0g9o9+7d6b777tMGZ/EDiuMhjcyYMUNYvO+44w7fIDec+/HHH6cuXbpQvkFdwLZMI1sVQPakyDYO1e4RZIyUFKCXCVNDzAeeA50h1XZs+fuTn6+gCQs30HV3/IMuark3hYnkGg/yAHizs58fkgWkDADM4p133qGRI0cKJgJ07txZSATQ5QFIEgsXLtztnMwk4KaDyMtSBN5kCHj7/PPPxfdXX301fQwYxtNPP50OiEINk7D8/lGCFymMjgi6isIgylGfujc5xuJAL+yDe6LaPZixmVBO8kyIGIiaNUXMDLabvBaAzgAOFSETpqECTOP/Zqykj35aR+c3rRaYaYCBFYTEkVScfPLJwhh6+OGHU8OGDalt27ZiOx4ebJszZ44whrLd4u677xa2B8QfwM4B9aVu3brFzgkG0bt3b2H4hMGNVRtg+PDhwqgK3b5jx47CrgFcfvnl9Msvv9CRRx4pHhoshv/85z9U0r0otpmyusAufAczwf5gbFjQukW9WQp7x4LTRYjqEJXBVEifY+fSayNfoR59rxMelKCwuQ85DzmH0XD06NEizRn2Ap3BCzqXygVxHBZnSQIeSq4Dcf/99wsr/6OPPmp9fJJoZhvmbaOLm96QNmOo+5iQTQyHLbIZgw2awz+dTyMevJPOuXIgDTy1WUaBf7DR+TGPyFQVWJ/xBoT+Lltp+/fvLwx9/fr1E9/xIHPMvukik+yOjRPvv/++oCmkERRbGTRoEOUrZBXA1ouiqh1+7libMdR9TIjj/ZrKYgzM/c05a+j5px6hM6+4iW7KkGlAcrKROCJTVeCihGsS+j5j1qxZQm/HQ3/llVcK9yFE8Z49ewqi4Y3K4jcD+jpERWY+nEuA/cMKOc4XnH/++eJPfiCC0AA0gzeHFxy+yyKzvBBlZs80B7C//IBzxSl1O/bHcZifPEceA8fhvoIJ4EHF/rDRqHPisbEf1DFWO2A45spfprHxH88hnjlIahhTHQP77b///oIu7NXA/pBi5AQ0+fzVq1cXDEcOTZfnq+4vb1d/M223OZfsJn77+7X00jdr6exr76Dr2u5J27dto+0ZjA37DOxsoJ1X8mBkjAM3VS3pBuPhiSeeKD6fcMIJwhsBAyMYCh6Kbt26ac/F5dkccgcsKPzp7mumqFixovizHV9eKFjo/EyYjJDYhwPLcDzG0kkdvF0ORMNnqDbMZFS7xYYNGygJWLRmCw2bsYq+XbqZ2h5Qgfq324syDQoF47R1J8e6GletWiX0J+bas2fPFgwEfyYcdthh4iapDxhHyhVi8lbYafUAv6ltF6r8tsEiQu0TLyNj1HU+8LzA68FzgHrL997EyCBdyN4LPH+qKI5rg6QLxgFDp7w/PmM77G/yi2v9+vXFpLA4gXngRbpp63Z6/bu19Nq0n2j11Ndp4M230on1KlG5PTJf0rh+ZAXXr1/fd99YjQe4+VyKDf9tuJupPQIeFrn4ioN/PY5MJQWbcnxhQ7VncNxFELet6kqFCiKHl6th56ziyDYR07WmLJ47MNOw7XMY9/uVRFe9t5henvITbfr4UXp6UF/qUr8ylQ7hJQo1BYGMiZI4EPWIkG
no6Qic6tGjh9VxOqkC1azwFiy0GhNRVwCLMsoyLECy4ALBsoTDf8xU8Pb1kqCY2UD94DBwObxcZYhQSbC/bNvIRqr
"text/plain": [
"<Figure size 288x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import os\n",
"\n",
"from scipy.special import logsumexp\n",
"\n",
"# Expected vs. observed waiting times (competing-exponentials view used by Delphi):\n",
"# each possible event i has an exponential waiting time with rate lambda_i = exp(logits[i]),\n",
"# so the expected time until ANY event is 1 / sum_i lambda_i = exp(-logsumexp(logits)).\n",
"# exp(-logsumexp) avoids overflow in exp(logits) and the 1/0 of the reciprocal form.\n",
"expected_t = np.exp(-logsumexp(p, axis=-1))\n",
"\n",
"# Logarithmic bins of width delta_log_t decades, starting at 10**1.75 days.\n",
"delta_log_t = 0.1\n",
"log_range = np.arange(1.75, 4, delta_log_t)\n",
"\n",
"# Mean observed waiting time per bin of expected waiting time (zero-day waits excluded).\n",
"observed_t = []\n",
"for log_lo in log_range:\n",
"    bin_mask = (expected_t > 10**log_lo) & (expected_t <= 10**(log_lo + delta_log_t)) & (t > 0)\n",
"    observed_t.append(t[bin_mask].mean() if bin_mask.any() else np.nan)\n",
"\n",
"fig, ax = plt.subplots(figsize=(4, 4))\n",
"ax.set_aspect('equal')\n",
"# +0.5 day offset, presumably so same-day events (t == 0) stay visible on the log axis.\n",
"ax.scatter(expected_t, t + 0.5, marker='.', c='lightgrey', rasterized=True)\n",
"# Bin averages plotted at the log-space centre of each bin.\n",
"ax.plot(10**(log_range + delta_log_t / 2.), observed_t, label='average')\n",
"ax.set_xlabel('Expected days to next token')\n",
"ax.set_ylabel('Observed days to next token')\n",
"ax.set_yscale('log')\n",
"ax.set_xscale('log')\n",
"ax.legend()\n",
"ax.set_xlim(1, 2e3)\n",
"ax.set_ylim(1, 2e3)\n",
"# Diagonal in axes coordinates; with identical x/y limits this is the y = x reference line.\n",
"ax.plot([0, 1], [0, 1], transform=ax.transAxes, c='k', ls=(0, (5, 5)), linewidth=0.7)\n",
"\n",
"ax.tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=1, grid_linewidth=0.45, grid_linestyle=':')\n",
"ax.tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=0.0, grid_linewidth=0.35, which='minor')\n",
"\n",
"# First figure written by this notebook: make sure the output directory exists.\n",
"os.makedirs(f'results_{model_name}', exist_ok=True)\n",
"fig.savefig(f'results_{model_name}/fig_expected_vs_observed_waiting_times.png', dpi=600)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e5e3c8f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAbkAAAEeCAYAAAAXTWt+AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAA1sxJREFUeJzsXQd4HNXVPbs72/uudtV7sSTbcu+FYhuwKcZU03sPLSQ/KUAIIaSQQCAJEDqhdzDNVFMM7t2SLcvqXdre+/zffSOtiiUjgwHH1vG3n7W7s7sz82befffec88V8TzPYxSjGMUoRjGKwxDin3oHRjGKUYxiFKP4oTBq5EYxilGMYhSHLUaN3ChGMYpRjOKwxaiRG8UoRjGKURy2GDVyoxjFKEYxisMWo0ZuFKMYxShGcdhi1MiN4ojH008/jblz5/5kv//www8jNTUVGo0Gdrv9e31XQ0MDRCIRYrEYfgp89dVXGDNmDH4qNDU1sfMYj8cP6raj+N/FqJE7QnD00UfDaDQiHA7v8/rjjz8+4LXPP/8cWVlZyedUSvnggw9i3LhxUKvV7L0zzzwTO3bs+NH2/3BFNBrFz3/+c3z00Ufw+Xwwm82HlNE6UMybNw/V1dU/2WIjJyeHnUeJRHJQt/2xkJeXh08++eSn3o3DCqNG7ggATZS0wqbJcsWKFQf8+RtvvBEPPPAAM3QOhwN79uzBqaeeivfee+8H2d//ZRyoMers7EQoFMLYsWN/sH06nDDqdY3iQDFq5I4A/Pe//8XMmTNx8cUX45lnnjmgz9bU1ODf//43XnzxRRx77LGQy+VQqVQ477zz8Ktf/Wqf7V9++WVMnTp1wGv3338/TjnlFPb3+++/j/Lycmi1WmRmZuJvf/vbflf1v/jFL5gHmp+fjw8++GDYFe+dd96J888/f4D389RTTyE7O5t9/pFHHsGGDRtQUVEBg8GAn/3sZwN+j7xVek2v16O0tBSffvpp8j23243LLrsM6enpbJ9vu+225GRL+zlnzhzcfPPNzAuj/RgM8p5vuukmZGRksAf9Ta/RYqE3tEf7ROd3MObPn598n0Jra9asQSKRwN13343c3FxYrVZceOGFbB+Hwuuvv87O1c6dO9nn/vznP6OwsJDt61lnncUWLf3PGV0f5OGkpKTgj3/8Y/J71q9fz8ZVp9Ox0Cp5n0NhcBSAfpvGmM47nduzzz6bGfXB2LVrF66++mp2fHScdLwEumavueYaLFmyhEURVq1axRZXkyZNYvtC49v/nA/2fClScfvtt7MxomvuuOOOg81mO+Bte+8jOud07v7whz/s1+va33X+7rvvYuLEiewYZ8+eje3bt7PXL7jgAhZCPfnkk9k5+Otf/zrkd4/iAEGyXqM4vFFYWMj/+9//5jdu3MhzHMd3dHQk3zvqqKP4xx57bMD2q1at4jMzM9nfDz/8MJ+TkzPi3/L7/bxGo+H37NmTfG3q1Kn8iy++yP5OS0vjv/zyS/a3w+HgN23aNOT3PPXUU2xfH330UT4Wi/EPPfQQn56ezicSCfZ+bm4u//HHHye3/93vfsefd9557O/6+nqSquOvuuoqPhgM8h9++CEvl8v5pUuX8p2dnXxLSwtvsVj4zz//PPlbEomEv++++/hIJMK/9NJLvE6n4+12O3v/1FNP5a+88kre5/Oxz0+bNo1/5JFHBnz2wQcf5KPRKB8IBPY5lttvv52fMWMG+2xXVxc/a9Ys/rbbbhuwr/TZoTDU+0888QQb09raWt7r9fLLli3jzz///H22f/LJJ9l2NTU17L1//OMfbD+am5v5UCjEjmn58uUDPnf55ZezY9i6dSsvk8n4qqoq9v7MmTP5//73v+xv+s01a9YMub/9r53ecaLz1drays5naWkpu6aGG/M5c+YMeO2iiy5iY7F69Wo+Ho+z8aTf2L59O3u+bds23mq18m+++eaQ54uu74KCAr66upodFz2/9dZbD3jbys
pKXq1W81999RUfDof5W265hV2f/a/B/hjuOt+8eTO79tauXcuu66effpqdIxqP3vM13HeO4rth1JM7zLF69Wo0NjayVfuUKVPYKv6FF14Y8eeJCEEezEhBXt7SpUuZ59frCe7evTvpyUmlUlRVVcHj8TAPa/LkycN+F62ar7jiCpYzueiii9De3s7CeyMFrcoVCgVbkZMXcM455zDPh1bWlDvasmVLclt6nTws2j/yNsjDIo+Bfo9W5f/4xz/Yd9B25LW99NJLyc+Sd3b99deD4zgolcp99uP555/HHXfcwT5rsVjwu9/9Ds8+++yIj2Oo7yNPqqCggK34//SnP7H96R8qpf299957mWdVVFTEXiNvlrwz8rTIIycP6LXXXhvwOdo3OoYJEyawx7Zt29jrdF727t3LPBv6TYoMjBQ33HADO0cmk4l5KVu3bj2g46XribwrsVjMxpM8rvHjx7Pn5CHSuH7xxRfDfv6SSy5BSUkJOy66D/b3+8NtS+eJ9p2iCzKZDHfddRfzAofDcNf5o48+iquuugozZsxIXtc0FmvXrj2gczKKkWPUyB3moPATTfIUfiKce+65A0KWNDET+aE/6DndpAQKzZBxORDQb/QaOTKolL8j49cbPiOjQQbsqKOOYuGp4ZCWlpb8u/fzRBQYKSis1guatAY/7/9dZPj6T1q0f21tbWyBQOeDDD2Fl+hBk1RXV1dyWwqZ7Q/0PfR9g7/7u2Ko7yND1X8BQAbuuuuuGxA6pGNZtmxZ8jjKysrYRNv/c4PPee85euKJJ1h4lUK506ZNYyG3kWK47xwpBp/fdevW4ZhjjmELBgqBkvHuH1b8Pr8/3LZ0zvvvB703mCTUH8Nd5zQGf//735NjQI/m5ubvdT2MYv8YNXKHMYLBIF555RW2yqWblx6UH6PVee8KnfIvlJvoj/r6+uQkumDBArS0tGDjxo0j/t1Fixahu7ubrYLJ2JHR6wVNkG+//TYzEmT8aLX8XUBeVSAQSD7v6OjA90FrayvLy/WCciPkfdDERittmkRdLhd70Oq8srIyue3+VvQE+h6a3AZ/90gw1HcP9X20WOlvxImtSXk7mmx7QcdCec3e46AH5cfIwH8biouL2VjSuN16660444wz4Pf7cTAx3Hkc/DpdTxQZIONAuUjK5f3QzVRokUP3Qf97a3/lHsNd5zQGv/3tbweMAV3H5I2O5FoaxYFj1MgdxnjrrbfYSp3CJmRw6EEJfgrVURKdQKE5ImgQsYAmClqtkyFcvnx5cnK79tpr2U1Ioa9IJMImRgqPEYlhKJAXSCUGv/zlLxmxgYwegT5LoTaamGgbIg5QyOm7gBL3tA/kZZEBpnDS9wFNRsQepe979dVX2XkisgNNbuQJ33LLLcy4EXmjtrZ2v+GxwaBzRwaHDD8ZSwp19ZJkvg3krdA5qqurG/B9NEa0GCFP4ze/+Q0bRzJ0vSC25sqVK5k318uoJWNAE2yvgaT9oYl4JHjuuefY9rQvvaSQ7zp2w4GMNBkSuk72B6/Xy0KfFLqk6/ZAwu/fFWTU33nnHXzzzTds/yjUO5xh3d91TuF38jzJG6XP00KBwuJ0TL3noP9Yj+L7Y9TIHcagsCTlGMhb6/Xk6EEsQroJKcR1/PHHM2NF21HohyZ2yhNceeWVye+hyZ8+QxMmTXCU13vzzTdZjmI40GqbmGdk7PpPvpSLIlYa3fh0s9N+fBcQu42MDeU7KI/U31v8LqAcCeUPKaxLhoCMZm84ihYENHERW45+jya8AwnhEhuTmImUP6JcEuVn6LWRgMJitD+Uk6JzT7mbSy+9lDHxiHlJrFOa7P/5z3/u81nKqVFYkSZW8uCoFIQ8IDLaxPqjvBpNtiMBGUwynJSPo++hBcZQ+cfvA2KX0m/QNdobXh8KDz30EMtx0jHQguG7RgMOBLRfdI5p8UcLHzoPlGMlL38oDHed03Xw2GOPsfuJriXKlxJDtxe//vWv2YKIxno45vEoDg
wiYp8c4GdGMYpRjOKIBnnQZIhoYUQLjVEcuhj15EYxilGMYgSgcCXlzyjESPWb5JWTtzaKQxujRm4UoxjFKEYAyl/
"text/plain": [
"<Figure size 504x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Per-disease AUC against the number of occurrences of that token in the training set.\n",
"auc_df = pd.read_parquet(f'results_{model_name}/df_both.parquet')\n",
"\n",
"fig, ax = plt.subplots(figsize=(7, 4))\n",
"ax.scatter(auc_df['count'], auc_df['auc'],\n",
"           c=auc_df['color'], s=24, edgecolor='white', linewidth=0.65)\n",
"ax.axhline(0.5, color='k', linestyle='--', linewidth=0.75)  # chance level\n",
"ax.set_title('AUC vs number of tokens in training set')\n",
"ax.set_xscale('log')\n",
"ax.set_ylim(0, 1.05)\n",
"ax.set_xlabel('Number of tokens in training set')\n",
"ax.set_ylabel('AUC')\n",
"fig.savefig(f'results_{model_name}/fig_auc_vs_data_size.png', dpi=600)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22ff1f44",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAFbCAYAAAAtA38vAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAAqxpJREFUeJztnQWclUX3x4dddunuLgkJExQ7UFFULBSwEwP1NV6xwRYbFQtBVFLFwEJEFBPBVkS6Qbp7gfv/fMf33P9w3WKfuXf3Pvf8Pp+FvbHPPDPPzDm/E3OmWCQSiRiFQqFQKBSKJEdaYd+AQqFQKBQKhQ8oqVEoFAqFQhEKKKlRKBQKhUIRCiipUSgUCoVCEQooqVEoFAqFQhEKKKlRKBQKhUIRCiipUSgUSYN77rnHnH/++dl+NmHCBFO3bt2E31NRx8UXX2zuuuuuwr4NhSIhUFKjUBQSjj76aFOpUiWzbdu2f70/cODAXBU25aWeeeYZ07p1a1OmTBn72dlnn23++OOPhN1/KhGB7du3W0LVtGlTO94NGzY0l156qZk3b170mZUsWdKUK1fOlC9f3hx44IGmb9++/3q2sejRo4dp3ry5SUtLM6+++uq/Pn/qqadMzZo17TVpL6/rJRqMw2effVbYt6FQRKGkRqEoBKAMv/76a1OsWDHz/vvv7/Hf/+c//zFPP/20JTarV682M2bMMKeffrr56KOPvN/rjh07TKqjS5cu9jkNHz7crFu3zvz222+WuIwfPz76nf79+5sNGzaYv//+2zzxxBNm5MiRplOnTpaA5oR9993XPP/88+aAAw7412djx461xIg25s+fb+bMmWP69OljwgLGZdeuXYV9G4qQQUmNQlEIeP3110379u2tR+C1117bo7+dOXOmee6558yIESPMsccea0qUKGFKly5tzjvvPHPbbbfl6xqffvqp9RBUqFDBXHPNNeaoo46KeofwGBx22GHmxhtvNFWqVLEeChT5hRdeaKpVq2YaNGhgHnjggahCig0JQdgga0KG8GLcfvvt5qCDDrIeh9NOO80SMcH3339vDj30UFOxYkWr5PFKCebOnWvvDQ/I8ccfb1auXJln3x566CFTtWpV60UYNmyYfe+HH34wNWrUMDt37ox+75133rHt5QU8EePGjTOjR4827dq1M8WLF7fj1rNnT3PZZZf96/t4cugzJGjixIm5Ek2u0aFDB+vliQXzguu3atXKevTuvvvubL05gm+++SY6jvXq1dvtu2vWrDEnn3yyHceDDz7YzJ49ezeCzPfFwwTZFvBsIXRdu3a1fwv5gtCBCy64wCxYsMCceuqppmzZsubRRx/N83kyLnfeeaedX8xZiJpC4RNKahSKQiI1kBB+sMiXLVuW77/FcifcBEkoCCAGKKqHH37YrFq1ypKb7777brfvTJo0yTRu3NjeF0rouuuus8QGJfTll1/a+x88eHC+2+T7r7zyivViQAquv/56+/7ixYutsiXUA9F5/PHHzVlnnWVWrFhhPz/33HOtouWeUep5EcClS5fa73Jdvkt4Z/r06ZaMQNAgc4IhQ4ZYopYfUsNYo/j3BPXr1zdt27bdjSTsCf7888/dSBe/8zx4ZrHAk3PSSSfZ58TY/frrr2a//faLfo7XCC8P5Gavvfayz1TA2PB9xp/xJoy5devW6OeQOd6Tz/EIZmVl2fGjjx988IHZuHGj6dWrV57PE/B3AwYMsF4tCLJC4RNKahSKBAOLGiV0zjnnWIXdpEkTG9bIL1BqtWrVKnD7H3/8sbX+zzzzzCjBIG/DRe3ata2C5PPMzEyrFCFBWOt4QG6++WarnPILrHrJ/7n//vvNm2++ab0mQ4cOtSEafsgrwRsDEeAe8QLgYeH7eKOOPPJI6xXIC/J9PDwoWNoCF110kW0PoHAhkyjpeI434+h6pfYEEAU8QgL5HTIQC+bPcccdZ7p3724yMjIsgXNJzRlnnGGJGc8TIg
2JEeBl4/t8xnMlbwciKGCOQoK57k033WQJD96Y7JDb8xTgnWT+0R7XVCh8QkmNQpFg4EE44YQTbIgEoFhdDwTCHkvYBa9FAaCA8HgUFEuWLNnN60CoKHbXkPs5ng/ad61qfscqzy/c6/G3XI/rQu7eeustG6qQH0gf/eM+CbtAhNy/zQ3ZfZ/riPLGq7Bp0yZLdI444oh8kZUg480YVa5c2f5OiEZ+IGx5ge+tX78++lp+h1jGYuHChZYc5wSXtBL2gTAJ8KbsvffeljQx/njk3DCf++wgKswVGdNY5PY8s7ueQuEbSmoUigRiy5YtVqESwkHR8MMOF/IUJFcBl77sqnFzS0Shk4OxaNEi8+OPPxboHlDk/L2bsOm+FqIjgHxBqFBYApRynTp17O+QiM2bN+8WAspO6bp/y/W4LgoOL87atWujP5AOcoO4T8IlvHb/Njdk9328JYD7PeSQQ2wuDV4m2s0P8IBMnjz5X2OUF+jzTz/9ZMkTgEjID884L+DNkDkB+J28IEhWLBhHN08mvyA0Ri4Mc5KxY/whN25ys/vsyKNiHGRM3Xki95HT8xTE/o1C4RNKahSKBOK9994z6enpZurUqTYEwM9ff/1lFR95J4CkTPJVUKQoF3Y2QXy6detmP2dbMcm9hBpIwmS7MSEBQkTslgEkiRImyg6EZNj6zb2QzEvScXZERMD9EiojD4PQB+TmySefjCYHE+b46quvLIHAyidMlV1Ygj5Dfnr37m3DGVxXvCeEgghH0Q/6hOKExBG6IBeEPmLx8928IN9HYX/44Yc2H0RADg1KnP4TfssvqSGMQggHksKYMQ4vvviizROKBX2EtJIQTciHUExOkGfHc8Z7xe+SgM29Dho0yI4b5IDkbEI32YGQErk/kBPuj5CZG2LKCfQDzyAJ4Pzdfffdt5t3CNBniCCf9+vXz4b2SHIHkCw32Te356lQJAQRhUKRMHTs2DFy0003/ev9N954I1KjRo1IVlaWfT1o0KBIy5YtI+XKlYs0adIk8vDDD0d27twZ/f6uXbsi/fr1s98pVapUpHbt2pFzzjknMmXKFPv5fffdFzn33HNzvI8xY8ZEmjZtGilfvnzk6quvjrRv3z7y+uuv288GDx4cOeyww3b7/urVqyPnnXdepGrVqpG6detG7r333t3u55prrolUqFDB3uuAAQMw86N9OeqooyK33XZbpF27drY/p5xySmTFihXRv/3+++8jRx55ZKRSpUr2+p06dYrMnz/ffjZ79uzI4YcfHilTpkzkuOOOi/Ts2dPeR3b44osvInXq1Ik88MADkSpVqkTq1asX7ZNg06ZN9h4uvPDCXJ/TRRddFLnzzjujr7dt2xbp3bu37V/p0qUj9evXj1x22WXR+6SPJUqUiJQtW9b+7LfffvY+tmzZkms7/B1j5f7QD8ETTzwRqV69ur3niy++OLJ169Ycr/XVV19FDjroIPtdntGrr76abV9knMCOHTsil1xyif2bmjVrRh555JFIgwYNIuPGjbOf9+nTJ3LWWWfZuSX9+umnn6LXeu+99+w48+wfe+yxPJ8n/X355ZdzHROFIgiK8U9i6JNCoUgUyNmhjg25EnkBzwB5Emx/PuaYY7zfC9t4seAvv/xyUxRA7slLL71kPTCK3MGW7lmzZkUTrBWKoo7ihX0DCoXCP9yty9mB8AD1SkqVKmUee+wxG/6QkEKY8fbbb9ucDur7KBSK8EFJjUKRgqAoHLuuyOlo2bKlza+B4IQZeIzITyFJmF08CoUifNDwk0KhUCgUilBAzRWFQqFQKBShgJIahUKhUCgUoYCSGoVCoVAoFKGAkhqFQqFQKBShgJIahUKhUCgUoYCSGoVCoVAoFKGAkhqFQqFQKBShgJIahUKhUCgUoYCSGoVCoVAoFKGAkhqFQqFQKBShgJIahUKhUCgUoUDoSc2ll15qqlevblq3bp3t5xx9df3115u99trL7LPPPubnn39O+D0qFAqFQq
EIjtCTmosvvth88sknOX4+ZswYM3PmTPszYMAAc/XVVyf0/hQKhUKhUPhB6EnNkUceaSpXrpzj56NHjzYXXnihKVa
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Distribution of per-disease AUCs within each ICD-10 chapter.\n",
"non_disease_chapters = ['Technical', 'Sex', 'Smoking, Alcohol and BMI']\n",
"\n",
"chapter_data = {}\n",
"chapter_colors = {}\n",
"# sort=False keeps chapters in order of first appearance (same order as Series.unique()).\n",
"for chapter, chapter_rows in auc_df.groupby('ICD-10 Chapter (short)', sort=False):\n",
"    if chapter in non_disease_chapters:\n",
"        continue\n",
"    # A single NaN makes matplotlib's percentile summary for the whole box NaN, so drop them.\n",
"    chapter_data[chapter] = chapter_rows['auc'].dropna().values\n",
"    # Each chapter's box uses the colour of its first row.\n",
"    chapter_colors[chapter] = chapter_rows['color'].iloc[0]\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"positions = np.arange(1, len(chapter_data) + 1)\n",
"\n",
"for position, (chapter, values) in zip(positions, chapter_data.items()):\n",
"    color = chapter_colors[chapter]\n",
"    # whis=[2.5, 97.5]: whiskers at the 2.5th / 97.5th percentiles.\n",
"    ax.boxplot(values, positions=[position], patch_artist=True,\n",
"               widths=0.6, whis=[2.5, 97.5], showfliers=True,\n",
"               boxprops={'linewidth': 1.25, 'facecolor': color, 'edgecolor': color},\n",
"               medianprops={'color': 'black', 'linewidth': 1.5},\n",
"               whiskerprops={'color': 'gray', 'linewidth': 1},\n",
"               capprops={'color': 'gray', 'linewidth': 1},\n",
"               flierprops={'marker': 'x', 'markerfacecolor': 'none',\n",
"                           'markeredgecolor': 'black', 'markersize': 3, 'alpha': 0.3})\n",
"\n",
"ax.set_xticks(positions)\n",
"ax.set_xticklabels(list(chapter_data), rotation=45, ha='right')\n",
"\n",
"ax.set_ylim(0, 1.025)\n",
"ax.axhline(0.5, color='black', linestyle='--', linewidth=0.75)  # chance level\n",
"\n",
"ax.yaxis.grid(True, linestyle='--', alpha=0.7)\n",
"ax.grid(axis='x', visible=False)\n",
"ax.set_axisbelow(True)\n",
"\n",
"ax.set_ylabel('AUC')\n",
"ax.set_xlabel('ICD-10 chapter')\n",
"ax.set_title('AUC, grouped by ICD-10 chapter', y=1.05)\n",
"\n",
"fig.tight_layout()\n",
"fig.savefig(f'results_{model_name}/fig_auc_by_icd10_chapter.png', dpi=600)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e90e7cbf",
"metadata": {},
"outputs": [],
"source": [
"# Preview of the per-disease AUC table (auc_df, loaded from df_both.parquet).\n",
"# Previously referenced an undefined name (df_auc_all_diseases), which raised NameError.\n",
"auc_df.head()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}