2025-10-20 10:14:50 +08:00
|
|
|
{
|
|
|
|
|
"cells": [
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 1,
|
|
|
|
|
"id": "08bd379a",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"import torch\n",
|
|
|
|
|
"from models import TimeAwareGPT2\n",
|
|
|
|
|
"from utils import load_model\n",
|
|
|
|
|
"from tqdm import tqdm\n",
|
|
|
|
|
"import pandas as pd\n",
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
"import textwrap\n",
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"plt.rcParams['figure.facecolor'] = 'white'\n",
|
|
|
|
|
"plt.rcParams.update({'axes.grid': True,\n",
|
|
|
|
|
" 'grid.linestyle': ':',\n",
|
|
|
|
|
" 'axes.spines.bottom': False,\n",
|
|
|
|
|
" 'axes.spines.left': False,\n",
|
|
|
|
|
" 'axes.spines.right': False,\n",
|
|
|
|
|
" 'axes.spines.top': False})\n",
|
|
|
|
|
"plt.rcParams['figure.dpi'] = 72\n",
|
|
|
|
|
"plt.rcParams['pdf.fonttype'] = 42\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"#Green\n",
|
|
|
|
|
"light_male = '#BAEBE3'\n",
|
|
|
|
|
"normal_male = '#0FB8A1'\n",
|
|
|
|
|
"dark_male = '#00574A'\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"#Purple\n",
|
|
|
|
|
"light_female = '#DEC7FF'\n",
|
|
|
|
|
"normal_female = '#8520F1'\n",
|
|
|
|
|
"dark_female = '#7A00BF'\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" \n",
|
|
|
|
|
"delphi_labels = pd.read_csv('delphi_labels_chapters_colours_icd.csv')"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 2,
|
|
|
|
|
"id": "1d8375ab",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"text/html": [
|
|
|
|
|
"<div>\n",
|
|
|
|
|
"<style scoped>\n",
|
|
|
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
|
|
|
" vertical-align: middle;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe tbody tr th {\n",
|
|
|
|
|
" vertical-align: top;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" .dataframe thead th {\n",
|
|
|
|
|
" text-align: right;\n",
|
|
|
|
|
" }\n",
|
|
|
|
|
"</style>\n",
|
|
|
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
|
|
|
" <thead>\n",
|
|
|
|
|
" <tr style=\"text-align: right;\">\n",
|
|
|
|
|
" <th></th>\n",
|
|
|
|
|
" <th>name</th>\n",
|
|
|
|
|
" <th>ICD-10 Chapter (short)</th>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </thead>\n",
|
|
|
|
|
" <tbody>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>46</th>\n",
|
|
|
|
|
" <td>A41 Other septicaemia</td>\n",
|
|
|
|
|
" <td>I. Infectious Diseases</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>95</th>\n",
|
|
|
|
|
" <td>B01 Varicella [chickenpox]</td>\n",
|
|
|
|
|
" <td>I. Infectious Diseases</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>1168</th>\n",
|
|
|
|
|
" <td>C25 Malignant neoplasm of pancreas</td>\n",
|
|
|
|
|
" <td>II. Neoplasms</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>1188</th>\n",
|
|
|
|
|
" <td>C50 Malignant neoplasm of breast</td>\n",
|
|
|
|
|
" <td>II. Neoplasms</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>374</th>\n",
|
|
|
|
|
" <td>G30 Alzheimer's disease</td>\n",
|
|
|
|
|
" <td>VI. Nervous System Diseases</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>214</th>\n",
|
|
|
|
|
" <td>E10 Insulin-dependent diabetes mellitus</td>\n",
|
|
|
|
|
" <td>IV. Metabolic Diseases</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>305</th>\n",
|
|
|
|
|
" <td>F32 Depressive episode</td>\n",
|
|
|
|
|
" <td>V. Mental Disorders</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>505</th>\n",
|
|
|
|
|
" <td>I21 Acute myocardial infarction</td>\n",
|
|
|
|
|
" <td>IX. Circulatory Diseases</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>603</th>\n",
|
|
|
|
|
" <td>J45 Asthma</td>\n",
|
|
|
|
|
" <td>X. Respiratory Diseases</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" <tr>\n",
|
|
|
|
|
" <th>1269</th>\n",
|
|
|
|
|
" <td>Death</td>\n",
|
|
|
|
|
" <td>Death</td>\n",
|
|
|
|
|
" </tr>\n",
|
|
|
|
|
" </tbody>\n",
|
|
|
|
|
"</table>\n",
|
|
|
|
|
"</div>"
|
|
|
|
|
],
|
|
|
|
|
"text/plain": [
|
|
|
|
|
" name ICD-10 Chapter (short)\n",
|
|
|
|
|
"46 A41 Other septicaemia I. Infectious Diseases\n",
|
|
|
|
|
"95 B01 Varicella [chickenpox] I. Infectious Diseases\n",
|
|
|
|
|
"1168 C25 Malignant neoplasm of pancreas II. Neoplasms\n",
|
|
|
|
|
"1188 C50 Malignant neoplasm of breast II. Neoplasms\n",
|
|
|
|
|
"374 G30 Alzheimer's disease VI. Nervous System Diseases\n",
|
|
|
|
|
"214 E10 Insulin-dependent diabetes mellitus IV. Metabolic Diseases\n",
|
|
|
|
|
"305 F32 Depressive episode V. Mental Disorders\n",
|
|
|
|
|
"505 I21 Acute myocardial infarction IX. Circulatory Diseases\n",
|
|
|
|
|
"603 J45 Asthma X. Respiratory Diseases\n",
|
|
|
|
|
"1269 Death Death"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"execution_count": 2,
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "execute_result"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Delphi is capable of predicting the disease risk for 1,256 diseases from ICD-10 plus death. \n",
|
|
|
|
|
"# For illustrative purposes, some of the plots will focus on a subset of 10 selected diseases - the same subset in used in the Delphi paper. \n",
|
|
|
|
|
"\n",
|
|
|
|
|
"diseases_of_interest = [46, 95, 1168, 1188, 374, 214, 305, 505, 603, 1269]\n",
|
|
|
|
|
"delphi_labels.iloc[diseases_of_interest][['name', 'ICD-10 Chapter (short)']]"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2025-10-22 13:27:28 +08:00
|
|
|
"execution_count": null,
|
2025-10-20 10:14:50 +08:00
|
|
|
"id": "f8d86352",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
2025-10-22 13:27:28 +08:00
|
|
|
"Using device: cpu\n"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"ename": "TypeError",
|
|
|
|
|
"evalue": "load_model() got multiple values for argument 'device'",
|
|
|
|
|
"output_type": "error",
|
|
|
|
|
"traceback": [
|
|
|
|
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
|
|
|
"\u001b[31mTypeError\u001b[39m Traceback (most recent call last)",
|
|
|
|
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 6\u001b[39m\n\u001b[32m 4\u001b[39m device = torch.device(\u001b[33m'\u001b[39m\u001b[33mcpu\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m 5\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mUsing device: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdevice\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m model = \u001b[43mload_model\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mresults_n_embd_120_n_layer_12_n_head_12_learnable/config_n_embd_120_n_layer_12_n_head_12.json\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[32m 7\u001b[39m \u001b[43m \u001b[49m\u001b[32;43m1270\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mcpu\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 8\u001b[39m model.eval()\n\u001b[32m 9\u001b[39m model = model.to(device)\n",
|
|
|
|
|
"\u001b[31mTypeError\u001b[39m: load_model() got multiple values for argument 'device'"
|
2025-10-20 10:14:50 +08:00
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"seed = 1337\n",
|
|
|
|
|
"torch.manual_seed(seed)\n",
|
|
|
|
|
"torch.cuda.manual_seed(seed)\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"device = torch.device('cpu')\n",
|
2025-10-20 10:14:50 +08:00
|
|
|
"print(f'Using device: {device}')\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"model = load_model(f'results_n_embd_120_n_layer_12_n_head_12_learnable/config_n_embd_120_n_layer_12_n_head_12.json', \n",
|
2025-10-20 10:14:50 +08:00
|
|
|
" 1270)\n",
|
|
|
|
|
"model.eval()\n",
|
|
|
|
|
"model = model.to(device)"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2025-10-22 13:27:28 +08:00
|
|
|
"execution_count": 43,
|
2025-10-20 10:14:50 +08:00
|
|
|
"id": "a8658fb6",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"# Let's try to use the loaded model to extrapolate a partial health trajectory.\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"example_health_trajectory = [\n",
|
|
|
|
|
" ('Male', 0),\n",
|
|
|
|
|
" ('B01 Varicella [chickenpox]',2),\n",
|
|
|
|
|
" ('L20 Atopic dermatitis',3),\n",
|
|
|
|
|
" ('No event', 5),\n",
|
|
|
|
|
" ('No event', 10),\n",
|
|
|
|
|
" ('No event', 15),\n",
|
|
|
|
|
" ('No event', 20),\n",
|
|
|
|
|
" ('G43 Migraine', 20),\n",
|
|
|
|
|
" ('E73 Lactose intolerance',21),\n",
|
|
|
|
|
" ('B27 Infectious mononucleosis',22),\n",
|
|
|
|
|
" ('No event', 25),\n",
|
|
|
|
|
" ('J11 Influenza, virus not identified',28),\n",
|
|
|
|
|
" ('No event', 30),\n",
|
|
|
|
|
" ('No event', 35),\n",
|
|
|
|
|
" ('No event', 40),\n",
|
|
|
|
|
" ('Smoking low', 41),\n",
|
|
|
|
|
" ('BMI mid', 41),\n",
|
|
|
|
|
" ('Alcohol low', 41),\n",
|
|
|
|
|
" ('No event', 42),\n",
|
|
|
|
|
"]\n",
|
|
|
|
|
"example_health_trajectory = [(a, b * 365.25) for a,b in example_health_trajectory] "
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2025-10-22 13:27:28 +08:00
|
|
|
"execution_count": 44,
|
2025-10-20 10:14:50 +08:00
|
|
|
"id": "d7f2e4a1",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
"text": [
|
|
|
|
|
"Input trajectory:\n",
|
|
|
|
|
"0.0: Male\n",
|
|
|
|
|
"2.0: B01 Varicella [chickenpox]\n",
|
|
|
|
|
"3.0: L20 Atopic dermatitis\n",
|
|
|
|
|
"5.0: No event\n",
|
|
|
|
|
"10.0: No event\n",
|
|
|
|
|
"15.0: No event\n",
|
|
|
|
|
"20.0: No event\n",
|
|
|
|
|
"20.0: G43 Migraine\n",
|
|
|
|
|
"21.0: E73 Lactose intolerance\n",
|
|
|
|
|
"22.0: B27 Infectious mononucleosis\n",
|
|
|
|
|
"25.0: No event\n",
|
|
|
|
|
"28.0: J11 Influenza, virus not identified\n",
|
|
|
|
|
"30.0: No event\n",
|
|
|
|
|
"35.0: No event\n",
|
|
|
|
|
"40.0: No event\n",
|
|
|
|
|
"41.0: Smoking low\n",
|
|
|
|
|
"41.0: BMI mid\n",
|
|
|
|
|
"41.0: Alcohol low\n",
|
|
|
|
|
"42.0: No event\n",
|
|
|
|
|
"=====================\n",
|
|
|
|
|
"Generated trajectory:\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"43.9: B95 Streptococcus and staphylococcus as the cause of diseases classified to other chapters\n",
|
|
|
|
|
"44.6: J30 Vasomotor and allergic rhinitis\n",
|
|
|
|
|
"44.8: L03 Cellulitis\n",
|
|
|
|
|
"46.6: L30 Other dermatitis\n",
|
|
|
|
|
"46.8: K56 Paralytic ileus and intestinal obstruction without hernia\n",
|
|
|
|
|
"47.8: K76 Other diseases of liver\n",
|
|
|
|
|
"48.8: N20 Calculus of kidney and ureter\n",
|
|
|
|
|
"48.8: L24 Irritant contact dermatitis\n",
|
|
|
|
|
"49.5: I10 Essential primary hypertension\n",
|
|
|
|
|
"49.5: K59 Other functional intestinal disorders\n",
|
|
|
|
|
"50.6: B96 Other bacterial agents as the cause of diseases classified to other chapters\n",
|
|
|
|
|
"51.4: E14 Unspecified diabetes mellitus\n",
|
|
|
|
|
"51.6: E55 Vitamin d deficiency\n",
|
|
|
|
|
"51.7: A41 Other septicaemia\n",
|
|
|
|
|
"52.0: F41 Other anxiety disorders\n",
|
2025-10-20 13:47:50 +08:00
|
|
|
"52.1: F32 Depressive episode\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"52.3: M19 Other arthrosis\n",
|
|
|
|
|
"53.5: G96 Other disorders of central nervous system\n",
|
|
|
|
|
"53.7: E87 Other disorders of fluid, electrolyte and acid-base balance\n",
|
|
|
|
|
"53.8: J18 Pneumonia, organism unspecified\n",
|
|
|
|
|
"53.8: E66 Obesity\n",
|
|
|
|
|
"53.8: E11 Non-insulin-dependent diabetes mellitus\n",
|
|
|
|
|
"54.4: Death\n"
|
2025-10-20 10:14:50 +08:00
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"max_new_tokens = 100\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"name_to_token_id = {row[1]['name']: row[1]['index'] for row in delphi_labels.iterrows()}\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"events = [name_to_token_id[event[0]] for event in example_health_trajectory]\n",
|
|
|
|
|
"events = torch.tensor(events, device=device).unsqueeze(0)\n",
|
|
|
|
|
"ages = [event[1] for event in example_health_trajectory]\n",
|
|
|
|
|
"ages = torch.tensor(ages, device=device).unsqueeze(0)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"res = []\n",
|
|
|
|
|
"with torch.no_grad():\n",
|
|
|
|
|
" y,b,_ = model.generate(events, ages, max_new_tokens, termination_tokens=[1269])\n",
|
|
|
|
|
" # Convert model outputs to readable format\n",
|
|
|
|
|
" events_data = zip(y.cpu().numpy().flatten(), b.cpu().numpy().flatten()/365.)\n",
|
|
|
|
|
" \n",
|
|
|
|
|
" print('Input trajectory:')\n",
|
|
|
|
|
" for i, (event_id, age_years) in enumerate(events_data):\n",
|
|
|
|
|
" if i == len(example_health_trajectory):\n",
|
|
|
|
|
" print('=====================')\n",
|
|
|
|
|
" print('Generated trajectory:')\n",
|
|
|
|
|
" event_name = delphi_labels.loc[event_id, 'name']\n",
|
|
|
|
|
" print(f\"{age_years:2.1f}: {event_name}\")"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2025-10-22 13:27:28 +08:00
|
|
|
"execution_count": 45,
|
2025-10-20 10:14:50 +08:00
|
|
|
"id": "0b937d75",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
2025-10-20 13:47:50 +08:00
|
|
|
"source": [
|
|
|
|
|
"from utils import PatientEventDataset, get_batch\n",
|
|
|
|
|
"train_data_path = 'ukb_real_train.bin'\n",
|
|
|
|
|
"val_data_path = 'ukb_real_val.bin'\n",
|
|
|
|
|
"train_data_arr = np.memmap(train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)\n",
|
|
|
|
|
"val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)\n",
|
|
|
|
|
"block_length = 128\n",
|
|
|
|
|
"train_dataset = PatientEventDataset(train_data_arr, block_length)\n",
|
|
|
|
|
"val_dataset = PatientEventDataset(val_data_arr, block_length)\n",
|
|
|
|
|
"dataset_subset_size = 2048"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2025-10-22 13:27:28 +08:00
|
|
|
"execution_count": 46,
|
2025-10-20 13:47:50 +08:00
|
|
|
"id": "6a3bb2dc",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
|
|
|
|
|
"input_events, input_times, target_events, target_times = get_batch(val_dataset, slice(256))\n",
|
|
|
|
|
"with torch.no_grad():\n",
|
|
|
|
|
" p = model(input_events.to(device), input_times.to(device))\n",
|
|
|
|
|
" p = p.cpu().detach().numpy().squeeze()\n",
|
|
|
|
|
"t = (target_times-input_times).cpu().numpy().squeeze()\n",
|
|
|
|
|
"target_events = target_events.cpu().numpy().squeeze()\n",
|
|
|
|
|
"ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]\n",
|
|
|
|
|
"mask = ~np.isin(target_events, ignored_token_ids)\n",
|
|
|
|
|
"p = p[mask]\n",
|
|
|
|
|
"t = t[mask]"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2025-10-21 09:20:43 +08:00
|
|
|
"execution_count": null,
|
2025-10-20 13:47:50 +08:00
|
|
|
"id": "1c32bd6e",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2025-10-22 13:27:28 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ4AAAELCAYAAAAofGgWAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAAUUhJREFUeJztXQe0FEXWvgRBsiigmMCAJAmKIAgICoqrK6irmAVFZMW0KJhFwfgvSlJ0XVZXXTMuimFRJCOCgiggCKIgBiTnR3jAm/98BXesV1R1mp6e6Zn6znnnTeiurr7TdevmWyKRSCTIwsLCwgdK+jnYwsLCArCMw8LCwjcs47CwsPANyzgsLCx8wzIOCwuL3GMc33//faanYGFhETfGAezatSvTU8gaOlha/AFLj8zRokQc4jh27NhBBx54IOU7QAfA0iI1emzbto0KCgqoQoUKVL58+f3eq8eVKlWKtmzZIhbmwQcfLP74+w0bNojXVatWFf9xPOaF/yVKlKA9e/ZQUVERpRu7d+8W/0uXLp3SOCtXrqRp06bRJZdcQocffnjyXlWkdpWIEAPeFgksHVKnBxb7smXLxLlY2DVr1qTff/89+f6YY45JMhM+TsaKFSuSzEr+HgwE58f5NwLTuPfee2nAgAHJezUxjlioKhYWYQGSAC9u/N+8eXOx9/hePU7F5s2btd/HmWms37yNbr/zHrrgr3fTnK1V6J+z1zseHwuJw8IiLEAdYckA/ytXrpxkAniP79XjVFSuXFlIHHGXML78dRu9MW8Trdy6m7YUFlFR+7vpn9+XhfxEB5YuQTecopc2YmPjKCwspDJlyhT7DPrmr7/+mtRz8wH8U+GB9Qs86EceeSQdcMABlCvAcwGoz0bYNg7YKdauXZv8/Nhjj02eB/Ee/1WULFkyEtsGA3MEYI/xgu/W7KT7xq+iGhVKUaNDD6RDK5be+1dh7//KZUvSEUccEW8bh26hgGlUqlSJateuHWghxRH8IOKh9Mtw1q1bJ2gGHT5XEPR3x6KXGYT6Xv18zZo1xT4vKChIHq9jGkCVKlWEwRJGVd28eRPAM4w/tp34BRgFMw2v+H3LLnp48mqqWLiemhZ8TX/tfK32OCejcyxsHDrODUnjkEMOyRumkQpAI9Aq16QzPBded3UscDAA00J3OobVFkBWZwr22UN0gLFUxzQAWcgHA9q+fTsFhV+msXnnHnpo4moq3LSG9kwYRud16mg81un+YiFxmB4OyzS8Ixdp5YdpyJ4U9px4PQb/8V5VZ0p5VAvc5sYqVyrwYnHYtSdBj05ZQys376AKnz1LTzw6kI466ijj8U73FwvGYWHhFTp7hepJkVUNhtsx8lj83u9ur4NJKgkbRYkEDZ2xlhas3kn92h5GrS591tU25HR/lnFkGfBjhbGT5SNMUoPqSWFVQ4bbMbqxS2l+J9grZGZQrVo18ZvifBj0ozSYynht7kaa8tM2uqbpQdSu9v73r4OOTrGycfg1BkaJCy64gJo1a0YNGzakf/7zn/SPf/yD+vXrl/z+pZdeoptvvlm8fvXVV6lFixbUtGlT6tWrV5KjV6xYke644w5q0qQJzZgxgwYOHEjNmzenE088kW644YbkTjhr1ixxLv5wDXwPYBy8xzmNGzem559/PlIaeLEfpOu5kJ8NndQgqxqHHnqoVk3xcow69oYNG0TgmG4ceF0wDv7Ddbtx40bauXNnKEwDjEknKYCZmdTRcT9spdc+/5EOWTiKLmlY2dN14H3T0SlWEocb4xjwwQJauGJzqNdscHhlevD8hq7Hvfjii8JlBQMXFu6ECROodevWNGjQIPH9W2+9Rffddx9999134vX06dPFj9K7d2967bXX6JprrhEP5amnnkpPPfXU3ms3aED9+/cXr6+++mr68MMP6fzzz6cePXoIpoDx77777uQcXnjhBWHFB2PBA4rvzz777Eg8KF7sB1E9F05Sg8lzIsPpGDWuY/
v27fvZFfiasloD+oQZ8YBNwo9NZP6qHTTs00W089Oh9NDQx0KzdcWCcWRzqMnw4cPp3XffFa9/+eUXsYiw08ycOZPq1KlDixYtEgt5xIgR9NVXXwnmwg9ejRo1xGuIvH/5y1+SY06aNIn+/ve/i4du/fr1Qppp27atEIFbtWoljrniiisEQwHGjRtH8+bNo3feeUe837RpEy1ZsiQSxuHFfhDVc2EyYoYBHhuSxsaNG/fzUCFXpVy5cslcFQ5jDwMYF65RXNuPVLe2YDc9MXUVbZnwHI188lE6ptbRvnNfYs043IxQXiSDdGDy5Mk0fvx4oV7gwWrfvr14aC677DJ6++23qV69enThhRcmd6pu3brR448/vt84eChYX8b5kEZmz54tLN4PPfSQqxsVYz/99NPUqVMnihpe7AdRPhdeJIugRlb8V0PNK1asmNwAdLktYUlWTq5RQL0uPCiPT1tDhXtK0CvPP0O1Dynn+5qO3/sazaIYsLNjp8EDBckCUgYAZjFmzBh64403BBMBOnToICSC1atXi/eQJJYvX77fmMwkYFTbunVrUoo46KCDhH77xRdfiPdvvvlm8hwwjOeeey6ZVo0aJm4PWljwYj8IA6nEYQS5FpjAqlWrxH8ej12nJaSYDjANHUMJExjbr8t25FfrafHaQrqt1SG+mYac7RtriSNbcc455whjaP369alu3brUsmXLJNHx2cKFC4UxlO0WjzzyiLA9wEgGOwfUl1q1ahUbEwyiZ8+ewvB52GGHJVUbYOTIkcKoit2gXbt2wq4BXH/99fTTTz/RySefLB7e6tWr03vvvRcZHcLa5dMVh+EXUAlUQ6gqUVStWjW5abjltqjA74ZNJ10YNfNHeu2fL1OPm/tQm1rBJEA3z17Gc1VgNBw7diwtXbpU2At0IhIMfmXLlt3vPCzOfAKyMiEag0ZPPPGE0KOHDRvm+fy40gxSBHZ/BqQbMEc8FwCeDdMxQfDbb78lmQUABgFPhtv4y5cv9xSXgZoZbjYEv+DxZi/5jW7tew+1urofDe7alEqVDGYM5XycyFUVxN5jB4T+LhOpT58+wtB32223ifd4kPHDcD2DuLljo8RHH30kaAppBMVW7r//fsoHmEK+ZXes6RgnmFQbVUzHe7fxt+2L0/CCsJkGzwnh5HcO+D867sK/0cALGgdmGpA23KS1tKkqcFHCNQl9nzFnzhyht+Ohv/HGG4X7EKL4tddem6yNwOI3A/o6iq0wsUEg3BiOz1QwTabQtWtX8SczWD80AM3gzeEFh/eygVFeiPLDzTQHcLwspHLFKfVzHI/z1HwSt2urn2MMqHXI7MWzAw8DPpPdoTiei/JgAcMWhHPU2AmeE87DWPCC4TWuC5URY/NYqH4FOxTe4z8+wzON9+XLlxfjQ93AOHwMz5dpLUP+zeTvTJ97GUv+vHD3bhr8+Qaqev6d1L9jDapQOlFszfi5NuxrYIJOmdRpYxyQNNTsOhgPzzrrLPG6Y8eOwhsB4oOhQEe/6qqrtGOBAKmWRLOIDljUWLRYYGGVOcRYauATLwz+HAsff14ib+X4CvzHe5zLwPzBFLC4YLBm2wk8XeXKlRPHQzXRMW4wojDC0b0Cc3t9/hb6ZlUh3dS8KtU5xF+ZgSCIdDXC/w3dCYBksWDBAsFA8GfCCSecIB5E+UeVI+VyMXkr7LR6gB98lY4mmHYbt8+xoDiGAbtwmJ4W07WDMCdID9i0mC54HsGA2A2LP91mxZLVhg0b9otclRHVRvfTryvojseepp2n/ZXOOrY8/alucYk9CFB7BBGvThJHpMYD/DhQRwD8hwchaHsEPCyoMZHNwWHZAq7HEUWRY1PYd7aEqTu5kWU3rC52hlW2ZcuWCWkkCFQjfyr4+vtf6Kre/aigfhe6olElurG5+3ryYyDOGokDUY8ImYaejsCp7t27ezpPJ1VA50VhGrXISi4jjApg2RAQ5td1aqrQFYQO6ljyeGosBmwleM8Bep
A2EIeTSGGzwvXYG5QKPlteQHcPGEKHnncb3d+5ETWqHm5lNzdDb9oYBy78pz/9iebOnSsClB577DGRj4EfAV4VJGp
|
2025-10-20 13:47:50 +08:00
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 288x288 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
|
|
|
|
"from scipy.special import logsumexp\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Calculate expected waiting times from model predictions using competing exponentials theory\n",
|
|
|
|
|
"# In Delphi's framework, each possible event has an exponential distribution with rate λᵢ = exp(logits[i])\n",
|
|
|
|
|
"# The expected time until any event occurs is 1/sum(λᵢ) = 1/exp(logsumexp(logits))\n",
|
|
|
|
|
"# logsumexp provides numerical stability vs. calculating exp(logits) directly\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Let's see how the predicted waiting times compare to the observed waiting times\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"plt.figure(figsize=(4, 4))\n",
|
|
|
|
|
"# Calculate expected time to next token (inverse of hazard rate)\n",
|
|
|
|
|
"expected_t = 1/np.exp(logsumexp(p, axis=-1))\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Define bin width for logarithmic binning\n",
|
|
|
|
|
"delta_log_t = 0.1\n",
|
|
|
|
|
"log_range = np.arange(1.75, 4, delta_log_t)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"# Calculate average observed time for each logarithmic bin\n",
|
|
|
|
|
"observed_t = []\n",
|
|
|
|
|
"for i in log_range:\n",
|
|
|
|
|
" # Create mask for current bin and valid times\n",
|
|
|
|
|
" bin_mask = (expected_t > 10**i) & (expected_t <= 10**(i+delta_log_t)) & (t > 0)\n",
|
|
|
|
|
" # Calculate mean for this bin\n",
|
|
|
|
|
" bin_mean = t[bin_mask].mean() if bin_mask.sum() > 0 else np.nan\n",
|
|
|
|
|
" observed_t.append(bin_mean)\n",
|
|
|
|
|
"plt.axes().set_aspect('equal')\n",
|
|
|
|
|
"plt.scatter(expected_t, t+0.5, marker=\".\", c='lightgrey', rasterized=True)\n",
|
|
|
|
|
"plt.xlabel('Expected days to next token')\n",
|
|
|
|
|
"plt.ylabel('Observed days to next token')\n",
|
|
|
|
|
"plt.plot(10**(np.arange(1.75,4,delta_log_t)+delta_log_t/2.),observed_t, label='average')\n",
|
|
|
|
|
"plt.yscale('log')\n",
|
|
|
|
|
"plt.xscale('log')\n",
|
|
|
|
|
"plt.legend()\n",
|
|
|
|
|
"plt.xlim(1,2e3)\n",
|
|
|
|
|
"plt.ylim(1,2e3)\n",
|
|
|
|
|
"plt.plot([0,1],[0,1], transform = plt.gca().transAxes, c='k' , ls=(0, (5, 5)), linewidth=0.7)\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"plt.gca().tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=1, grid_linewidth=0.45, grid_linestyle=':')\n",
|
2025-10-21 09:20:43 +08:00
|
|
|
"plt.gca().tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=0.0, grid_linewidth=0.35, which='minor')\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"plt.savefig(f'results_n_embd_120_n_layer_12_n_head_12_learnable/fig_expected_vs_observed_waiting_times.png', dpi=600)"
|
2025-10-20 13:47:50 +08:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "4e5e3c8f",
|
|
|
|
|
"metadata": {},
|
2025-10-21 09:20:43 +08:00
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2025-10-22 13:27:28 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAbkAAAEeCAYAAAAXTWt+AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAA26xJREFUeJzsXQeYFFXWPd1VnXP39OQcgJkBhpwzAgZUzJizrrrqmv513WBYdYO7665h3TUra84BAwYQQcl5CAOTc+qcu7r7/+6r6Z7AAIOiIsz5vv5mOlVX1at69917zz1XEovFYhjEIAYxiEEM4hiE9KfegUEMYhCDGMQgfigMGrlBDGIQgxjEMYtBIzeIQQxiEIM4ZjFo5AYxiEEMYhDHLAaN3CAGMYhBDOKYxaCRG8QgBjGIQRyzGDRygzhu8Pzzz2PatGk/2e8/8cQTSElJgVarRWdn5/faVk1NDSQSCQRBwE+Br7/+GkOHDv3Rf3fWrFl4+umn2f8vvfQS5s+fP6DPHi7q6urYOEUike+8r4M4OjBo5I4x0I1tMpkQDAYPecOvWLECmZmZiedUMvnII49g+PDh0Gg07L1zzjkH27dv/9H2/1hFOBzGrbfeimXLlsHj8cBisRxVRutwMX36dOzZs+cn3YcLL7yQnc8jgdzcXHz++eeJ59nZ2WycOI7D0YC+9+ogBo5BI3cMgSZKWmHTZPn+++8f9vdvvvlm/Otf/2KGzmazoaKiAosWLcLSpUt/kP39OeNwjVFraysCgQBKS0t/sH0axCAGsT8GjdwxhBdffBGTJk3CZZddhhdeeOGwvrt37148/vjjeOWVVzBnzhwoFAqo1Wq2Wr7zzjv3+/xrr72GcePG9Xrt4Ycfxmmnncb+/+ijj1BSUgKdToeMjAz87W9/O2gI8fbbb2ceaF5eHj7++OMDrrDvueceXHTRRb28n+eeew5ZWVns+//5z3+wfv16jBw5EkajEb/85S97/R55q/SawWDAsGHD8MUXXyTeczqduPLKK5GWlsb2+Xe/+10iXEX7OXXqVNxyyy3MC6P96Avynn/1q18hPT2dPeh/eo0WC/HQHu0Tnd++mDFjRuJ9CpN9++23iEajuP/++5GTk4Pk5GRccsklbB/7w1tvvcXO1Y4dO9j3/vznP6OgoIDt67nnnssWLT3PGV0f5K0kJSXhgQceSGxn3bp1bFz1ej0LrZL3ORDPgn6bxpjOO53b8847jxn1/s4RHSPtZxzt7e1QqVRoa2uD3W7HwoULYbVa2XjS/w0NDQMKP3/22WdsTOn3aYx7ijlVVlay807ng46ZrmuHw8Heu/jii1l48tRTT2Xn/q9//et+nnVTUxO7ts1mMwoLC/HUU08ltk3XAp1jGh+63mkhs2HDhn73mfaJriEaTzrHI0aMSJwLOjd0H9C40Ln/xS9+Ab/fD6/Xi5NOOontA+0fPej/QQwQJOs1iGMDBQUFsccffzy2YcOGGM/zsZaWlsR7M2fOjD311FO9Pr98+fJYRkYG+/+JJ56IZWdnD/i3vF5vTKvVxioqKhKvjRs3LvbKK6+w/1NTU2MrV65k/9tsttjGjRv73c5zzz3H9vXJJ5+MCYIQ+/e//x1LS0uLRaNR9n5OTk7ss88+S3z+7rvvjl144YXs/+rqaprFYtdee23M7/fHPv3005hCoYidfvrpsdbW1lhDQ0PMarXGVqxYkfgtjuNi//jHP2KhUCj26quvxvR6fayzs5O9v2jRotg111wT83g87Pvjx4+P/ec//+n13UceeSQWDodjPp9vv2P5/e9/H5s4cSL7bltbW2zy5Mmx3/3ud732lb7bH/p7/5lnnmFjWllZGXO73bEzzjgjdtFFF+33+WeffZZ9bu/evey9f/7zn2w/6uvrY4FAgB3T4sWLe33vqquuYsewZcuWmFwuj+3cuZO9P2nSpNiLL77I/qff/Pbbb/vd357XTnyc6Hw1Njay8zls2DB2Tf
WHyy+/PHbXXXclnj/22GOxBQsWsP87Ojpib775Jru+XC5X7Oyzz2bj2d91TGMydepU9n97ezu7Ht944w02tjTGNF7xz9K5WbZsGTsfNDbTp0+P3Xzzzb32v+d11nc86PPXXXcdu842b94cS0pKin3xxReJa5Kuu6VLl7Jr+M4772Tnvz988sknsTFjxsTsdju7xum8NzU1sfd+9atfxU499VR2/ujYFy5cyLbV3/kexMAxaOSOEXz99dfMWNDNThg6dCi70Qdq5O6///4D3pgHAhmbe++9l/1Pxo4mGZqcCFlZWcxAOJ3Og26DJiqaoOOg79Pk0tzcPGAjR8YsDrPZzIxXHGeeeWbs4YcfTvxWTwNKoImZJnVaENBk39N4vfzyy7FZs2YlvkvHdDDk5+ezia7nhEb7/12N3Jw5c9iiJY7du3ezMabPxD//0EMPxYqLi5lBi4MMzOeff554TpNo3+/1/Dydg/jihCbzP/zhD4nr6EDoz8gtWbIk8fyOO+5gi4/+QONJ5yqOKVOmxF544YV+P0sGxWg0HtLI0fd7Xr80xrR/fa/5ON55553YqFGjBmTk6urqYlKplBmeOMj4XHrppYlrcu7cuYn3ysvLY0qlst/fJcNYVFTEFg+RSKTX/qrV6ti+ffsSr33zzTex3Nxc9v+gkfvuGAxXHiOg8BMxzSgUQ7jgggt6hSx5nmfkh56g5zKZjP1PYZzm5ubD+k36DQpvEl5++WWWv6MQZzx8RiFLCrXNnDmThd8OhNTU1MT/8e9T0n+goNBOHBT26vu857YoDElhqDho/yj0U1tby84HhSopnEaPa6+9loXQ4qCQ6MFA26Ht9d32d0V/26PwGeX34njooYdwww039Aod0rGcccYZieMoLi5mBIqe3+t7zuPn6JlnnmHhVQr7jR8/Hh9++OGA9/dA2+yL2bNnw+fzYe3atSwsuGXLFra/BHqdzjsdK4XzKIxLYcVDsRzpXPUcHxrjns/p2BcvXszGn7ZLIe+Ojo4BHRdtm8KUFIqMg/avsbHxgMdOodr+8rYUMqVQKo0ZhSyvueYauFwuFrKlYx87dmxi3E488UT2+iC+HwaN3DEAitu//vrr+Oqrr9jNRg/Kj23dupU9CBTnpwmlJ6qrqxOT6Ny5c1nu40C5hP4wb948dhPSJEXGjoxeHDRBvvfee8xIkPGjnMV3AbE86eaPo6WlBd8HNDH1zNVQLobyZzQhUh6SJj6aVOlBk095eXnisz2NY3+g7ZCB6bvtgaC/bfe3PVqs9DTixC6kvB0tKuKgY6G8Zvw46EGTLk3wh0JRUREbSxq3X//61zj77LNZTuhIggwuXQ/0O/SgvFvcgPz9739nrE0ygHT+V65cyV4/VLMUWpzU19cnntPnez6/66672DkmpjBt93//+1+vbR5sbGkcKKfpdrt7jcVAzmd/uOmmm7Bx40bs3LmTLShooUKLU1qQ0fUWHzPKv8YXCoe69gZxYAwauWMA7777Lps46KYhg0OPXbt2MZo3kVEIRAQgggYRC+jmppuLDCGtbuOT2/XXX4/zzz+fkQpCoRCbGF999VVGYugP5AVSicEdd9zBJgEyegT6LtUw0U1Kn6GVs1T63S61UaNGsX0gL4sM8JtvvonvA5q8iT1K23vjjTfYeTr55JPZJEme8G233cYmQSJvEFmBFg4DBZ07Mjhk+MlY3nfffQmSzKFARAs6R1VVVb22R2NEixGa7GiipnEkQxcHkRw++eQT5hnEGbVEWPjtb3+bMJC0P7TgGAho8qfP076QN0H4rmN3MNCCiMhLdJ30XByRIaHJnn6brql77713QNs75ZRTmIF4++23mQdFY9xzQUTbJcIGkVJooUOGpSdo4dDz3PcELRqmTJmC3/zmN+ye2LZtG/N4Bzq2PUGkKDLgdP3RAk6pVLLzS4+rr76akVLi0QPaz08//TSxf1RbeSDi0SAOjEEjdwyAwpKXX34589binhw9KCxCkwjd9A
sWLGDGij5HNzpN7JdeeikLl8RBE0M8lEKTDLHz3nnnHcY6OxBogiL2Ixm7npPvkiVLGOOODBwxHmk/vgv++Mc/MmN
|
2025-10-21 09:20:43 +08:00
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 504x288 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
|
2025-10-22 13:27:28 +08:00
|
|
|
"auc_df = pd.read_csv(f'results_{model_name}/df_both.csv')\n",
|
2025-10-21 09:20:43 +08:00
|
|
|
"plt.figure(figsize=(7, 4))\n",
|
|
|
|
|
"plt.scatter(auc_df['count'], auc_df['auc'], \n",
|
|
|
|
|
" c=auc_df['color'], s=24, edgecolor='white', linewidth=0.65)\n",
|
|
|
|
|
"plt.axhline(0.5, color='k', linestyle='--', linewidth=0.75)\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"plt.title('AUC vs number of tokens in validation set')\n",
|
2025-10-21 09:20:43 +08:00
|
|
|
"plt.xscale('log')\n",
|
|
|
|
|
"plt.ylim(0, 1.05)\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"plt.xlabel('Number of tokens in validation set')\n",
|
2025-10-21 09:20:43 +08:00
|
|
|
"plt.ylabel('AUC')\n",
|
2025-10-22 13:27:28 +08:00
|
|
|
"plt.savefig(f'results_n_embd_120_n_layer_12_n_head_12_learnable/fig_auc_vs_data_size.png', dpi=600)"
|
2025-10-21 09:20:43 +08:00
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": null,
|
|
|
|
|
"id": "22ff1f44",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
2025-10-22 13:27:28 +08:00
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAFbCAYAAAAtA38vAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAApERJREFUeJztnQeYVMXyt4tddsk5ShQQQcAM5gyKouJVEcWsKAbMXjFjzmLEhGBCwYABEyJ6xYhgTkiUnAVEcpzvedt/zdeMywa2z+zu2XqfZ2Vndjxnzjnd1b+uqq4uk0gkEmIYhmEYhlHCySjqL2AYhmEYhhECEzWGYRiGYcQCEzWGYRiGYcQCEzWGYRiGYcQCEzWGYRiGYcQCEzWGYRiGYcQCEzWGYRR7br75Zjn11FNz/Nvo0aOlUaNGaf9OxZ0zzzxTbrjhhqL+GoaRVkzUGEaaOeigg6RGjRqydu3af70/cODAXAdsyko98sgj0q5dO6lUqZL72wknnCC//PJL2r5/aRIC69atc4KqZcuW7n5vu+22cvbZZ8v06dOTz6x8+fJSpUoVqVq1quy+++5y9913/+vZptKrVy9p1aqVZGRkyHPPPfevvz/44INSv359d0zOl9fx0g334aOPPirqr2EY/8JEjWGkEQbDzz//XMqUKSNvv/12gf//Sy+9VB5++GEnbJYsWSKTJk2S//znP/Lee+8F/64bNmyQ0k63bt3ccxoyZIgsW7ZMfvrpJydcPv744+Rn+vfvL8uXL5d58+ZJv3795OWXX5YuXbo4Aboldt55Z3n88cdlt912+9ffRo4c6YQR55gxY4b88ccfctNNN0lc4L5s2rSpqL+GEVNM1BhGGnnhhRdkr732ch6B559/vkD/7+TJk+Wxxx6ToUOHyiGHHCLlypWTihUryimnnCLXXHNNvo7x4YcfOg9BtWrV5MILL5QDDzww6R3CY7DvvvvK5ZdfLrVq1XIeCgby008/XerUqSNNmzaV22+/PTkgpYaEEGyINRVDeDGuvfZa2WOPPZzH4ZhjjnFCTPn6669ln332kerVq7tBHq+UMm3aNPfd8IAceuih8ueff+Z5bXfeeafUrl3beRFeeukl994333wj9erVk40bNyY/98Ybb7jz5QWeiFGjRsnw4cOlQ4cOUrZsWXffevfuLT179vzX5/HkcM2IoDFjxuQqNDlGx44dnZcnFdoFx2/btq3z6N144405enOUL774InkfGzduvNlnly5dKkceeaS7j3vuuadMnTp1M4HM59XDhNhWeLYIuhNPPNH9v4gvBB2cdtppMnPmTDn66KOlcuXKcu+99+b5PLkv119/vWtftFmEmmFEgYkaw0izqEGE8MOMfMGCBfn+f5m5E25CJGwNCAMGqrvuuksWL17sxM1XX3212WfGjh0rzZs3d9+LQejiiy92woZB6NNPP3Xf/9lnn833Ofn8M88847wYiIJLLrnEvT9nzhw32BLqQejcf//9cvzxx8uiRYvc308++WQ30PKdGdTzEoDz5893n+W4fJbwzsSJE50YQaAh5pTBgwc7oZYfUcO9ZuAvCE2aNJH27dtvJhIKwm+//baZ6OJ3ngfPLBU8OUcccYR7Tty7H3/8UXbZZZfk3/Ea4eVB3Gy33XbumSrcGz7P/ed+E8Zcs2ZN8u+IOd7Tv+MRXL9+vbt/XOM777wjK1askD59+uT5PIH/b8CAAc6rhUA2jCgwUWMYaYIZNYNQ9+7d3YDdokULF9bILwxq22yzzVaf//3333ez/+OOOy4pMMjb8GnQoIEbIPl7dna2GxQRQczW8YBceeWVbnDKL8zqNf/ntttuk1dffdV5TV588UUXouGHvBK8MQgBviNeADwsfB5v1AEHHOC8Anmhn8fDwwDLueCMM85w5wMGXMQkg3SU95v76HulCgJCAY+Qor8jBlKh/XTq1El69OghWVlZTsD5oubYY491wozniZBGxCh42fg8f+
O5kreDEFRoo4hgjnvFFVc4wYM3Jidye54K3knaH+fjmIYRBSZqDCNN4EE47LDDXIgEGFh9DwTGnpmwD691AGAAwuOxtcydO3czrwOhotRVQ/7f8Xxwfn9Wze/MyvOLfzz+X47HcRF3r732mgtV6A+ij+vjexJ2QQj5/29u5PR5jqODN16FlStXOqGz//7750usFOZ+c49q1qzpfidEoz8Itrzgc3///Xfytf6OsExl1qxZThxvCV+0EvZBMCl4U3bYYQcnmrj/eOT8MJ//7BAqtBW9p6nk9jxzOp5hRIWJGsNIA6tXr3YDKiEcBhp+WOFCnoLmKuDS11U1fm6JDujkYMyePVu+/fbbrfoODOT8/37Cpv9ahY6C+EJQMWApDMoNGzZ0vyMiVq1atVkIKKdB1/9/OR7HZYDDi/PXX38lfxAd5AbxPQmX8Nr/f3Mjp8/jLQG+79577+1yafAycd78gAdk3Lhx/7pHecE1f/fdd048AUJCf3jGeYE3Q9sE8Dt5QYisVLiPfp5MfiE0Ri4MbZJ7x/1H3PjJzf6zI4+K+6D31G8n+j229DyV1P/HMKLARI1hpIG33npLMjMzZfz48S4EwM/vv//uBj7yToCkTPJVGEgZXFjZhPA56aST3N9ZVkxyL6EGkjBZbkxIgBARq2WAJFHCRDlBSIal33wXknlJOs5JiCh8X0Jl5GEQ+kDcPPDAA8nkYMIcn332mRMQzPIJU+UUluCaET99+/Z14QyOq94TQkGEo7gOromBExFH6IJcEK6RGT+fzQv9PAP2u+++6/JBFHJoGMS5fsJv+RU1hFEI4SBSuGfchyeffNLlCaXCNSJaSYgm5EMoZkvos+M5473id03A5rsOGjTI3TfEAcnZhG5ygpASuT+IE74fITM/xLQluA48gySA8//deuutm3mHgGtGCPL3hx56yIX2SHIHRJaf7Jvb8zSMtJIwDCNyOnfunLjiiiv+9f4rr7ySqFevXmL9+vXu9aBBgxJt2rRJVKlSJdGiRYvEXXfdldi4cWPy85s2bUo89NBD7jMVKlRINGjQING9e/fEr7/+6v5+6623Jk4++eQtfo8RI0YkWrZsmahatWriggsuSOy1116JF154wf3t2WefTey7776bfX7JkiWJU045JVG7du1Eo0aNErfccstm3+fCCy9MVKtWzX3XAQMGMM1PXsuBBx6YuOaaaxIdOnRw13PUUUclFi1alPx/v/7668QBBxyQqFGjhjt+ly5dEjNmzHB/mzp1amK//fZLVKpUKdGpU6dE79693ffIiU8++STRsGHDxO23356oVatWonHjxslrUlauXOm+w+mnn57rczrjjDMS119/ffL12rVrE3379nXXV7FixUSTJk0SPXv2TH5PrrFcuXKJypUru59ddtnFfY/Vq1fneh7+P+6V/8N1KP369UvUrVvXfeczzzwzsWbNmi0e67PPPkvsscce7rM8o+eeey7Ha9H7BBs2bEicddZZ7v+pX79+4p577kk0bdo0MWrUKPf3m266KXH88ce7tqXX9d133yWP9dZbb7n7zLO/77778nyeXO/TTz+d6z0xjBCU4T/plVGGYUQFOTvUsSFXIi/wDJAnwfLngw8+OPh3YRkvM/hzzjlHigPknjz11FPOA2PkDku6p0yZkkywNoySQtmi/gKGYYTDX7qcE4QHqFdSoUIFue+++1z4Q0MKceb11193OR3U9zEMI76YqDGMUgRF4Vh1RU5HmzZtXH4NAifO4DEiP4UkYVbxGIYRXyz8ZBiGYRhGLLBpi2EYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYscBEjWEYhmEYsSCWoubss8+WunXrSrt27XL8O9tdXXLJJbLddtvJTjvtJN9//33av6NhGIZhGGGJpag588wz5YMPPtji30
eMGCGTJ092PwMGDJALLrggrd/PMAzDMIzwxFLUHHDAAVKzZs0t/n348OFy+umnS5kyZWSvvfaSv/76S+bNm5fW72g
|
2025-10-21 09:20:43 +08:00
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 576x360 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import matplotlib.patches as mpatches\n",
"\n",
"# One box per ICD-10 chapter, summarising the per-disease AUCs held in auc_df.\n",
"chapter_col = 'ICD-10 Chapter (short)'\n",
"non_disease_chapters = ('Technical', 'Sex', 'Smoking, Alcohol and BMI')\n",
"chapters = auc_df[chapter_col].unique()\n",
"\n",
"# Disease chapters only, in the order unique() returns them (this is also the x-axis order).\n",
"chapter_data = {\n",
"    ch: auc_df.loc[auc_df[chapter_col] == ch, 'auc'].values\n",
"    for ch in chapters\n",
"    if ch not in non_disease_chapters\n",
"}\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"\n",
"# Each box takes the colour of the first auc_df row belonging to its chapter.\n",
"chapter_colors = {\n",
"    ch: auc_df.loc[auc_df[chapter_col] == ch, 'color'].iloc[0]\n",
"    for ch in chapter_data\n",
"}\n",
"\n",
"positions = np.arange(1, len(chapter_data) + 1)\n",
"boxplots = []\n",
"for pos, (ch, aucs) in zip(positions, chapter_data.items()):\n",
"    colour = chapter_colors[ch]\n",
"    boxplots.append(ax.boxplot(\n",
"        aucs,\n",
"        positions=[pos],\n",
"        patch_artist=True,\n",
"        widths=0.6,\n",
"        whis=[2.5, 97.5],  # whiskers at the 2.5th and 97.5th percentiles\n",
"        showfliers=True,\n",
"        boxprops={'linewidth': 1.25, 'facecolor': colour, 'edgecolor': colour},\n",
"        medianprops={'color': 'black', 'linewidth': 1.5},\n",
"        whiskerprops={'color': 'gray', 'linewidth': 1},\n",
"        capprops={'color': 'gray', 'linewidth': 1},\n",
"        flierprops={'marker': 'x', 'markerfacecolor': 'none',\n",
"                    'markeredgecolor': 'black', 'markersize': 3, 'alpha': 0.3},\n",
"    ))\n",
"\n",
"ax.set_xticks(positions)\n",
"ax.set_xticklabels(list(chapter_data), rotation=45, ha='right')\n",
"ax.set_ylim(0, 1.025)\n",
"ax.axhline(0.5, color='black', linestyle='--', linewidth=0.75)  # chance-level AUC\n",
"ax.yaxis.grid(True, linestyle='--', alpha=0.7)\n",
"ax.set_axisbelow(True)\n",
"ax.set_ylabel('AUC')\n",
"ax.set_xlabel('ICD-10 chapter')\n",
"ax.set_title('AUC, grouped by ICD-10 chapter', y=1.05)\n",
"\n",
"fig.tight_layout()\n",
"ax.grid(axis='x', visible=False)\n",
"fig.savefig('results_n_embd_120_n_layer_12_n_head_12_learnable/fig_auc_by_icd10_chapter.png', dpi=600)"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
2025-10-22 13:27:28 +08:00
|
|
|
"execution_count": 50,
|
2025-10-21 09:20:43 +08:00
|
|
|
"id": "e90e7cbf",
|
|
|
|
|
"metadata": {},
|
2025-10-22 13:27:28 +08:00
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"# Per-disease AUCs (no gap) for the n_embd=256 and n_embd=120 models.\n",
"df_256 = pd.read_csv('results_n_embd_256_n_layer_16_n_head_16/df_both.csv')\n",
"df_120 = pd.read_csv('results_n_embd_120_n_layer_12_n_head_12/df_both.csv')\n",
"\n",
"# Delphi reference: average the female and male no-gap AUCs into a single 'auc'\n",
"# (DataFrame.mean skips NaN, so if one sex is missing the other value is used).\n",
"df_delphi = (\n",
"    pd.read_csv('delphi_auc.csv')\n",
"    .assign(auc=lambda d: d[['AUC Female, (no gap)', 'AUC Male, (no gap)']].mean(axis=1))\n",
"    .rename(columns={'Name': 'name'})\n",
")"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 51,
|
|
|
|
|
"id": "e5e7676a",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"# Reduce each model's table to disease name + AUC, then inner-join both models\n",
"# and the Delphi reference on disease name.\n",
"df_256 = df_256[['name', 'auc']]\n",
"df_120 = df_120[['name', 'auc']]\n",
"df_whole = (\n",
"    df_256\n",
"    .merge(df_120, on='name', suffixes=('_256', '_120'))\n",
"    .merge(df_delphi[['name', 'auc', 'N tokens, training', 'Colour', 'ICD-10 Chapter']], on='name')\n",
"    .rename(columns={'auc': 'auc_delphi'})\n",
")\n",
"df_whole.to_csv('model_comparison_auc.csv', index=False)"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 52,
|
|
|
|
|
"id": "75f66e69",
|
|
|
|
|
"metadata": {},
|
2025-10-21 09:20:43 +08:00
|
|
|
"outputs": [
|
|
|
|
|
{
|
2025-10-22 13:27:28 +08:00
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAboAAAEcCAYAAACxsnF2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAAvGJJREFUeJzsnQWcXNX1x7/jvruz7pJk4+4eEkgCAYJrgeJaKNYWqpQKtKWlFPsXd3dJCAECAeLu2STr7jvu8//cO1mLE5Ls7nR/dJp9M/PezHn3zf29c+45v6MIh8NhetGLXvSiF72IUii7+gv0ohe96EUvenE80Ut0vehFL3rRi6hGL9H1ohe96EUvohq9RNeLXvSiF72IavQSXS960Yte9CKq0Ut0vehFL3rRi6hGL9H1ohe9+EE46aSTeOaZZ370e0tLSzGbzQSDwWP8DXvRi87oJbpeHFOIic1qteL1eg874X399ddkZma2bYuSzv/85z8MHToUk8kkX7vgggvYvHnzIT/zH//4h9zHYrGQl5cntzsiNzcXg8EgJ1XxmDNnTqfXCwsLOeOMM+T+iYmJ/PKXv+RE4L777kOhULBy5cr9nr/sssv2e7947+7du9u2Fy1axPTp0+X3TkpKYsaMGXz00UeH/MwXXngBlUrVdi7E+brqqqsoKCjgRCM7OxuHwyG/Ty96cTzRS3S9OGYoLi7m22+/lRPy4SbcA+HnP/85jzzyiCS7xsZGOfmeffbZfPrpp4fcTxDkSy+9RFNTE5999hmPPfYYb7zxRqf3fPzxx3JSFY/PP/+87Xmfz8fs2bOZNWsW1dXVlJeXH5BkjjVav3N8fLz894finXfekTcBV1xxhfzONTU13H///dLOw2HSpEnyPLS0tPDFF1/Im4AxY8awZcuWo7SmF73o3uglul4cM4gJe+LEiVx55ZW8+OKLP2jfXbt28fjjj/P6669L0tHpdBiNRn7yk59wzz33HHJf4YGNHj0atVrNgAEDOOuss/j++++P6HOFh5Oens6dd94pvUi9Xs/w4cMP+N6bbrqJu+++u9Nz4rP+9a9/yb//9re/kZGRIT0s8T2+/PLLg36uuCGoqqqSpC5IWRDuDyFJ8X1/97vfce211xIbG4tSqZQe3dNPP33ExxGeVN++fXniiSfkvsKTbMWKFSuYPHkycXFxjBgxQnrfBzt/U6ZM4Wc/+5n8HgMHDtzP7pKSEvkecV6EN11fX992YyRuigKBwAGPLTzxhx56SI6HOPZFF12Ex+Npe13Y2q9fP3mzMH/+fCorKw95bebk5JCQkMCf/vQneWxB8gKrVq2S5C9sTUtLk7Z0HA/xHcU49enTR3r8v/jFLwiFQkd8nnvR9eglul4cM4jJRBCTeIiwmvAyjhRichShyvHjx/+o7yBIQJDIkCFDOj0vvpMI74mJduPGjZ0mdDHpnXbaaXISEyHWg4VKL7nkEt588035GQLCgxTe4cUXX8zOnTulJ7l69Wrsdru0Xxz3YBA3AmeeeSYXXnih3D4ST6wV4rPKyso4//zzOVY499xz5XkTqKio4PTTT+e3v/2t9KwF2Zx33nnU1dUdcF8RehWEKQjsj3/8ozyW2K8Vr732Gs8//zy1tbWSQMTxjhRvvfWW9NKLiorYtGmTJFaBr776invvvVe+Lm4YBImJcTgQtm3bxs0338yrr74q3ys8WWFjR8J/+OGH5fdfvny5vBYF+XfE+++/z5o1a1i3bh0ffvghzz333BHb0IuuRy/R9eKY4LvvvpN37mLiFmEwMfGJCe5I0dDQIO+mfyyEVyLutsW6UyvEBCe8B/H9Zs6cydy5c2lubpavibCf8Khuu+026RGICV54aQfysKZNmybv7lsJQYQPhScgPEIxWYp1STGp+v1+SXLiHBwILpeLt99+m0svvRSNRiMJ64eEL8W5EjgW56sVwoZWcnrllVeYN2+efAhPUYR2x44dy4IFCw64b3JyMr
fffru0RXhdwpvtGG4WY9G/f38ZIhXXx4YNG474e4lxEd9NeG3ixqB1XzGmV199tfTkhff/wAMPSJIS47wvxDiJfadOnYpWq5UhXjGOrRDXq4hEiIiAGLcbbriBb775ptMxfvWrX8nvINYVha0i8tCLnoNeouvFMYHwUIS3JLwiATGJdwxfiklEEEBHiG0xOQqIkJK42/4xEB6VIAwxyYrJrxUibCYmWREKFV6ACFG1kpV4XkyAwqMTk6AITQoi2b59+37HF5Oj8BpaJzlB5MJTFBAhtH//+9+SaMXEL953sFCa8A7E+RBEIiCOsXDhwjaP6WDnSkCcL3GuBH7s+eoI4eGIiVxA3BAIIhbnqfUhbmQO9nkiXNuROIR31dH21NTUtr/FGIj1wSPFwfYVxxef0wqRWCPOS0dPrRXivVlZWZ2O03oOBcRasEhGEp8VExPDr3/967bwais67r+vfb3o/uglul78aLjdbhlCEnfBYrIQDxEKEiHC1jChuBPe925bhKNaJ6uTTz5ZelciPHQ0EKGkBx98sC0EeiiISbk1/CjWfzpO0oeDCF8KD0GQgQjZiZBeKwS5t3q24pjCCzgQxA2AmLDFORHnSiSVCCJr9YAPdq4EAQpSER6TmHjfffddjhUE+QqPVUAc+/LLL5deb+vD6XQedK1UkEvHJiiibEB4YccT4vjiPLdCfD9xgyLOz74Qnq+4tjper61ecevaq1hbFOvENpuNv/71r53sERCh4hNpXy+OLXqJrhc/Gh988IEM3YmwnQgtiYfwiMTE2RqSEyEtsU4jFv7FJCLuogUZtq6r5Ofny3UUQSQi8UGEDkXigQgrCgI7FEQYS9yFL168WCYMdISYlERiSuvxROmBuFsXXp6AyLAU63QiMUHUcwmvTHilgwYNOuBnjRo1Sr4ukkBECFR4O63rZmLdSIQvRUKL8BRF2O9ApCDI+JNPPmk7V+JmQJBi67k69dRT2bFjBy+//LIkQBFSFPYJUhVkJ0hUJMCIpApxTsXkLMK1gmSvv/76Ix43Ya8g0FtvvVWe8z/84Q9t50SsGYp1RvEecd7E6x3JoiPE2ptI1hDfVXiCYuxbvdXjBXGdCNvF+RPnXJyfCRMmHHBdVISGhT3Lli2T14HwujsSmVhTFZ6c8ArFeX/yySf3O4a4bsSarCA8kRksrude9CCIfnS96MWPwdy5c8N33nnnfs+/+eab4ZSUlLDf75fbzz77bHjw4MFhi8US7tu3b/iBBx4IB4PBtveHQqHwv//9b/keg8EQTk9PD1944YXhLVu2HPLzc3Nzw2q1OmwymdoeN9xwg3xN7Dts2LCw0WgMx8fHh2fNmhVevXp1p/3fffdd+X3E95oxY8ZhP+/+++8Xs2T4rbfeantu48aN4XHjxoXNZnPYarWGTz/99HBFRcV++wqbR48evd/z4r3Chs2bN8vt77//PjxlypRwXFxcOC0tLXzNNdeEGxsbO+2zcOHC8NSpU6W9iYmJ8rt/8sknh/zuzz//fFipVMp9xDnJzs4OX3HFFeFt27Z1et+KFSvC06dPl7aIY8+bNy9cUlIiXxOf8/TTT7cdb/LkyeFbbrklHBMTE87Pzw8vWrSo7Tgd39v6fmGXQFFRkTyPrdfHvsjJyQkvXry4bfsPf/hD+Cc/+Unb9pNPPhnu06dP2/kuKys7pN1ZWVnyGhDjJ66tpUuXyte++eab8IABA+Q5Eefzd7/7Xdt3FBDf8ZFHHgnn5eXJ/cW1HggEDnmee9G9oBD/19Vk24te9KJnQmRBCiEA4U32FIiwsfDERahSFMwfDsKDFu8V67C96JnoDV32ohe9iHqI0KXIdhVreSLhaNiwYYcs/+hFdKGX6HrRIyDq4lplqzo+xPpcLzrjxhtvPOC5Es//r0LUvokEEvEQ3plY+/0hSUi96NnoDV32ohe96EUvohq9Hl0vetGLXvQiqtFLdL3oRS960YuoRi/R9aIXvehFL3o0hFqRqG+MKq
ITwqrRhGiz57///S/RhGgbn157ujeizZ4TYZMQnhBtvqIqGUWoQBxIdaKnItrsEbVVolVPtCDaxqfXnu6NaLPnWNs
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 504x288 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"# Per-disease AUC: n_embd=256 model (y) against Delphi (x), no gap.\n",
"fig, ax = plt.subplots(figsize=(7, 4))\n",
"# Points use Delphi's per-disease 'Colour' column.\n",
"ax.scatter(df_whole['auc_delphi'], df_whole['auc_256'],\n",
"           c=df_whole['Colour'], s=24, edgecolor='white', linewidth=0.65)\n",
"# Identity line: points above it have a higher AUC for the 256 model than for Delphi.\n",
"ax.axline((0.3, 0.3), slope=1, color='k', linestyle=(0, (5, 5)), linewidth=0.7)\n",
"# Chance-level AUC (0.5) on both axes.\n",
"for draw_reference in (ax.axvline, ax.axhline):\n",
"    draw_reference(0.5, color='gray', linestyle='--', linewidth=0.75)\n",
"ax.set_xlim(0.3, 1.05)\n",
"ax.set_ylim(0.3, 1.05)\n",
"ax.set_title('AUC_256 vs AUC_Delphi no gap')\n",
"ax.set_ylabel('AUC_256')\n",
"ax.set_xlabel('AUC_Delphi')\n",
"fig.savefig('model_comparison_auc_256_vs_delphi.png', dpi=600)"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 53,
|
|
|
|
|
"id": "3cf730fa",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAboAAAEcCAYAAACxsnF2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAAvdpJREFUeJzsnQWYXOXVx3/jvu6azWbj7oEYCQT34u5epPQrUAotUKCFllIoUKAUijsECZ6gSYhvspHdzbr7zI7r97zvrMYDSXZ32P/z3GRn5tq5773vucf+RxEKhUIMYhCDGMQgBhGhUPb1CQxiEIMYxCAGcTAxqOgGMYhBDGIQEY1BRTeIQQxiEIOIaAwqukEMYhCDGEREY1DRDWIQgxjEICIag4puEIMYxCAGEdEYVHSDGMQg9gvz58/n2Wef/dnrVlRUYDabCQQCB/gMBzGI3hhUdIM4oBATW2xsLB6PZ68T3rJly8jIyOj6LEo6//nPfzJ27FhMJpP87YwzzmDjxo17PObSpUs54ogjiI6OZsiQIb1+a2ho4JxzziEtLU3+fvjhh7Ny5cpe67zyyitkZ2fLY55yyim0tLRwKPDHP/4RhUKx0/mI788///yd1hfrFhcXd33+9NNPmTt3LhaLhcTERObNm8fixYv3eMznn38elUolFYxYcnJyuOSSSygsLORQIysrC7vdLs9nEIM4mBhUdIM4YCgrK+Pbb7+VE/LeJtxd4cYbb+TRRx+Vyk4oGzH5CsXz0Ucf7XE7oaAuvfRSHnrooZ1+ExPptGnTWLNmjdznRRddxPHHHy+/FygoKOCqq67ixRdfpL6+HqPRyLXXXsvBhlDq//vf/4iLi5P/7y/eeust+RJw4YUXUlVVJc/9nnvu4YMPPtjrtrNmzZLyW61WvvjiCwwGA1OmTGHTpk0/UZpBDKKfQzCjDGIQBwJ/+tOfQocddljo5ptvDh1//PG9fps3b17omWee6fXd0qVLQ+np6fLvwsLCkFKpDK1cufInH//zzz8PZWdn73U9i8USWr16tfz79ttvD51zzjldvxUXF4c0Gk3IZrPttN3VV18d+s1vftPru5NOOin0t7/9Tf794IMPhtLS0kJmszk0fPjw0BdffLHbc/j6669Der0+9NJLL4Xi4uJCHo+n67e77747dN555+20jXhci4qKQsFgMJSZmRn661//Gtpf/Pe//w0dfvjhO30vxuv000/v+rx8+fLQrFmzQtHR0aHx48fLsdrVWIr9iTG/7rrrQlFRUaERI0b0kluse+edd8p1xHU56qijQo2NjfK30tJSKZPP59vluYqxfOihh0Ljxo2T+z7zzDNDLper6/enn346lJubG4qNjQ2deOKJoerq6t3K/cILL4SysrLktb7nnnvkvsX9IiDuuZkzZ0pZU1JSpCw9x0Oc46OPPhrKyckJxcfHh2699dZQIBDY52s+iL7HoEU3iAMGYZmcd955chFuNWFl7Cu+/PJL6aqcPn36QT3H9evX4/V6GTZsWJdFN2HChK7fc3Nz0Wq1u3TlCRfo66+/Lq0xgdbWVj777DPOPvtstm3bxuOPP86qVatob2+X8u/oRu2JF154gRNPPJEzzzxTft4XS6wT4liVlZX86le/4kDhtNNOk9a4QHV1tbR677zzTmkFP/zww5x++uk0NjbuclvhehXXrampiT/96U9yXz3dv8I1/N///le6kcW1F/vbV7zxxht88sknlJaWkp+fL12vAl999RW33367/L22tla6nsU47AqbN2+WVvrLL78s1xWWrJCxE8J1+sgjj8jzX758ubwXn3jiiV77ePfdd1m9ejVr167l/fff57nnnttnGQbR9xhUdIM4IPjuu+8oLy+XE7dwg4mJT0xw+4rm5mZSU1MP6jnabDYuuOAC7r77bhmvExAuvM6/OyE+C2W1I+bMmSPdsp0KQbgPhRtQxP/EZCnikmJS9fl8UsmJa7ArOJ1O3nzzTc
4991w0Go1UWPvjvhTXSuBAXi8hQ6dyeumllzjuuOPkolQqOeqoo5g6dSoff/zxLrdNSkripptukrKcddZZjBgxope7WcQAhw8fLl2k4v4QLxv7il//+tfy3ISLV7wYdG4rlJZwV0+ePBmdTscDDzwglZRwn+8IMU5i29mzZ8uXGOHiFePYCXG/zpw5E7VaLcdNuLK//vrrXvv43e9+J89BxBWFrK+++uo+yzCIvsegohvEAYGwUBYtWkRCQoL8LCZx8V0nxCQiFEBPiM9ichSIj4+Xb9sHCy6XS052YkITlkAnREKGUIA9IT6LBI8dISZHYTV0TnJCkQvrVUBYiP/4xz9kIomY+MV6NTU1uzwXYR2I6yEUiYDYx5IlS7ospt1dKwFxvcS1EjiQ10tYOGIiFxAvLEIRx8TEdC3iRWZ3x0tPT++lOIR11VP2lJSUrr9FDLQzProv2N22Yv/iOD3HUVyXnpZaJ8S6mZmZvfbTeQ0FhPV+wgknyGNFRUVxxx13SOuuJ3puv6N8g+j/GFR0gzggSkS4kMRbsJgsxCJcQRs2bJCLgHgT3vFtW7ijOierhQsXyqQK4R460BCWlkhqEa7Rf//7371+GzNmTNc5CpSUlMj1hQWyKwj3pbAQhDIQLjvh0uuEUO6dlq2Y+IUVsCuIFwAxYYtrIq6VSCoRiqzTAt7dtRIKUCgVYTGJifftt9/mQEEoX2GxCoh9C8u3ra2ta3E4HNx222273FYol55NUETZgLDCDibE/sV17oQ4P2HpiuuzI4TlK+6tnvdrp1UscM011zBy5EiKiorkS87999/fSx4B4So+lPIN4sBiUNEN4mfjvffek6474bYTriWxbNmyRU6cnS454dIScZoff/xRTiLiLVoow864Sl5enoyjCEUiyg5ELMftdvPaa6/x4IMP7vH4wWBQriuUhdi3+FtsLyC+E65B4TYTCka44npCWFMiPibckWKyvOuuu2SMaVcWncCkSZOk1Xr55Zdz9NFHS2unM24m4kZCSer1enm8HY/VqRREDOjDDz/sulZC0Qql2HmtjjnmGLZu3SozQcX5C5eisDKEUhXKTijRv//979x7773ymorJWVwDoWSvvPLKfR43Ub8mFOgNN9wgr7lw6QqI0gZxTUScUawjrqf4vaey6AkRexOZsuJchSUoxr7TWj1YEPeJkF1cP3HNxfWZMWPGLuOiYvyFPD/88IO8L4TV3VORCTe1sOSEVSiu+5NPPrnTPkRGr4jJCoUnMoPF/TyIAYS+zoYZxMDH0UcfHbrlllt2+v71118PJScnd2XV/ec//wmNHj1aZj2KbLkHHnigV/aayCb8xz/+IdcxGAwyg1Fk2m3atGmPxxcZgeJW7rmIbD+BZcuWyc9ifyaTqWv55ptvurZ/+eWXZRaj0WiUWZTNzc17PJ7I2hP7fOONN7q+27BhQ2jatGkys1BkAYosxl1lAQqZJ0+evNP3Yl21Wh3auHGj/Pz999/L7MiYmJhQampq6LLLLgu1tLT02mbJkiWh2bNnS3kSEhKkzB9++OEez11kSYrsVrGNkFdkIl544YWhzZs391pvxYoVoblz50pZxL6PO+64UHl5+V6zLvPy8kKffvrpbrNte2Z97kvWZWdm5K6yUZ988snQ0KFDu653ZWXlHuUWY9yZdSnurc57QGTAimxRcU3E9fzDH/7QKzO1Z9al2F7c636/f4/XeRD9CwrxT18r20EMYhADEyILUhABCGtyoEC4jYUlLlyVomB+bxAWtFi3M1N3EAMPg67LQQxiEBEP4boU2a7CPX3rrbcybty4PZZ/DCKyMKjoBjEgIJJGOmmrei4izXwQvXH11Vfv8lqJ73+pELVvIoFELMI6E7Hfnpmig4hsDLouBzGIQQxiEBGNQYtuEIMYxCAGEdEYVHSDGMQgBjGIiMagohvEIAYxiEEMaAi2IlHfGFGKThCrRhIiTZ4d2UcGOiJtfAbl6d+INHkOhUyCeEK0+Y
qoZBTBArEr1omBikiTR9RWXXzxxUQKIm18BuXp34g0eQ60TIKtR7DiCCYeQd+2LxiQV1PQ9EQSIk0ecVNHEiJtfAb
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 504x288 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"# Per-disease AUC: n_embd=120 model (y) against Delphi (x), no gap.\n",
"fig, ax = plt.subplots(figsize=(7, 4))\n",
"# Points use Delphi's per-disease 'Colour' column.\n",
"ax.scatter(df_whole['auc_delphi'], df_whole['auc_120'],\n",
"           c=df_whole['Colour'], s=24, edgecolor='white', linewidth=0.65)\n",
"# Identity line: points above it have a higher AUC for the 120 model than for Delphi.\n",
"ax.axline((0.3, 0.3), slope=1, color='k', linestyle=(0, (5, 5)), linewidth=0.7)\n",
"# Chance-level AUC (0.5) on both axes.\n",
"for draw_reference in (ax.axvline, ax.axhline):\n",
"    draw_reference(0.5, color='gray', linestyle='--', linewidth=0.75)\n",
"ax.set_xlim(0.3, 1.05)\n",
"ax.set_ylim(0.3, 1.05)\n",
"ax.set_title('AUC_120 vs AUC_Delphi no gap')\n",
"ax.set_ylabel('AUC_120')\n",
"ax.set_xlabel('AUC_Delphi')\n",
"fig.savefig('fig_auc_120_vs_delphi.png', dpi=600)"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 54,
|
|
|
|
|
"id": "4fc44bad",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"# Per-disease AUCs with a 1-year gap for the n_embd=256 and n_embd=120 models.\n",
"df_256 = pd.read_csv('results_n_embd_256_n_layer_16_n_head_16/df_both_1year.csv')\n",
"df_120 = pd.read_csv('results_n_embd_120_n_layer_12_n_head_12/df_both_1year.csv')\n",
"\n",
"# Delphi reference: average the female and male 1-year-gap AUCs into a single 'auc'\n",
"# (DataFrame.mean skips NaN, so if one sex is missing the other value is used).\n",
"df_delphi = (\n",
"    pd.read_csv('delphi_auc.csv')\n",
"    .assign(auc=lambda d: d[['AUC Female, (1 year gap)', 'AUC Male, (1 year gap)']].mean(axis=1))\n",
"    .rename(columns={'Name': 'name'})\n",
")"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 55,
|
|
|
|
|
"id": "bf680d14",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [],
|
|
|
|
|
"source": [
"# Build the 1-year-gap comparison table: reduce each model's table to disease name + AUC,\n",
"# then inner-join both models and the Delphi reference on disease name.\n",
"df_256 = df_256[['name', 'auc']]\n",
"df_120 = df_120[['name', 'auc']]\n",
"df_whole = df_256.merge(df_120, on='name', suffixes=('_256', '_120'))\n",
"df_whole = df_whole.merge(df_delphi[['name', 'auc', 'N tokens, training', 'Colour', 'ICD-10 Chapter']], on='name')\n",
"df_whole = df_whole.rename(columns={'auc': 'auc_delphi'})\n",
"# Saved under a _1year name (matching df_both_1year.csv and the *_1year.png figures)\n",
"# so it does not overwrite the no-gap table in model_comparison_auc.csv.\n",
"df_whole.to_csv('model_comparison_auc_1year.csv', index=False)"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 56,
|
|
|
|
|
"id": "f415d2e5",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAboAAAEcCAYAAACxsnF2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAAxOBJREFUeJzsXQV8FGf6fnZn3Xfj7gGCu0uhlJYW6tTdXa73r95d25P2rle7ujstFUqdOhVKcQsQEhLitptk3Wf3/3u/zcaIASGyzdPflMzu7My8881877z2vIJgMBjEMIYxjGEMYxgRCuFAn8AwhjGMYQxjGMcSw4puGMMYxjCGEdEYVnTDGMYwhjGMiMawohvGMIYxjGFENIYV3TCGMYxhDCOiMazohjGMYQxjGBGNYUU3jGH8wbFgwQK8/PLLR71teXk5VCoVeJ7v4zMcxjCODsOKbhg9Tmx6vR4ej6fHCW/dunVITk5uWacSzf/9738YM2YMlEol++7ss8/G7t27uz3mI488wn6jVquRkZHB1tsiPT0dcrmcTaq0nHDCCe2+LykpwSmnnMJ+Hx0djf/7v/9Df+D++++HQCDAxo0bD/n8wgsvPGR72vbAgQMt619//TXmzZvHzjsmJgbz58/Hp59+2u0xX3/9dXAc13It6HpddtllKCwsRH8jNTUVdrudnU9nePrppzFlyhRIpVJceuml/X5+w/jjYljRDaNLlJaW4pdffmETck8Tbme45ZZb8OSTTzJl19jYyCbf0047DV988UW3vyMF+eabb6KpqQlr165lE+R7773XbpvPPvuMTaq0fPPNNy2fe71eLF68GAsXLkRtbS0qKys7VTJ9jfA5GwwG9u/h4sMPP2QvARdffDE757q6Ojz44INMzp4wc+ZMdh0sFgu+++479hIwefJk5OfnYzAhMTER9913Hy6//HIMFvj9/oE+hWH0A4YV3TC6BE3YM2bMYG/fb7zxxmH9tqioCM888wzeffddpnToLV6hUOCCCy7AXXfd1e1vyQKbNGkSRCIRRowYgVNPPRXr16/v1XHJwqEJ9fbbb2dWpEwmw7hx4zrd9rrrrsMdd9zR7jM61mOPPcb+/ve//42kpCRmYdF5fP/9910el14IampqmFInpUwK93CUJJ3vX/7yF1x55ZXQarUQCoXMonvppZd6vR+ypLKysvDss8+y35IlGcbvv/+OWbNmQafTYfz48cz67ur6zZ49GzfeeCM7j5EjRx4id1lZGduGrgtZ0yaTqeXFiF6KulIeZ5xxBnvRiYqK6lYOunb0wtDW8q+vr2f3j9FoZOuff/45JkyYwOQhuXbt2tWy7cMPP8yuA51fXl4ePv7440Pku+2229h5tL1GYbhcLlxyySXMkzFq1Cj85z//aeep6M3+u7t+w+h/DCu6YXSr6Egx0UJuNbIyegt6uGlymDZt2lGdAykBUiKjR49u9zmdE7n3aKLduXNnuwmdXJsnnXQSc1uSi7UrV+l5552HVatWsWMQyIIk6/Dcc8/F/v37mSW5efNm2Gw2Jj/ttyvQi8CyZcuwYsUKtt4bSywMOlZFRQXOOuss9BVIqdB1I1RVVeHkk09m1hRZ1v/9739x5plntiiNjiDXK03kpMAeeOABti/6XRgrV67Ea6+9xpQPKSXaX19CIpGwMXj77bdbPqMXpkWLFrEx3759O7MKX3jhBTQ0NOCaa67B8uXLW9zrdO4kO1m4f/vb35hFTy8hbeXLzMxk9/O99957yPFJZlLa5AL/9ttv251Hb/ff3fUbRv9jWNENo1P8+uuv7M2dJm5yg9GDSxNcb0ETUEJCwlGfB71xBwIBFncK45133mETEZ3fcccdhyVLlsBsNrPvyO1HFtXNN9+M6upqNsGTldaZhTV37lxmgYQVArkPyQ1IFiFZRzRx7t27Fz6fjyk5ugadwel04oMPPsD5558PsVjMFNbhuC/pWhH64nqFQTKEJ1eaqJcuXc
oWshTJtUuxsi+//LLT38bGxuLWW29lspxzzjnMmm3rbqaxyM3NZS5Suj927NiBvgZZVKTcwi8hb731Fi666CL294svvsiU2/Tp09k40bbkMaCXHAK5gEl+kpXOPycnB5s2bWp3bW666SbmMSAZOuL999/HPffcwyw6elmje6ktetp/T9dvGP2PYUU3jC4tFLKWyCoi0CTe1n1JkwQpgLagdXq4CeQWavuWeyQgi4oUBk0SNJGFQa4hmqDIlXX33Xcz91VYWdHnc+bMYRYdWQbkmiRFsm/fvkP2T0qOLAeaUAmkyMlSJGRnZ+OJJ55gipYmLtqOFGdnINcVXQ9SJATax1dffdViMXV1rQh0vcKuvKO9Xm1BVhy5/wj0QkCKmK5TeKEXma6OR+5aujZhpKWltZM9Pj6+5W8aA4oP9jVIidG+ycVaUFDAknbIagvL8+ijj7aThyzi8DnSPRN2a9JCscqwe5WQkpLS7bFpP2236bh9T/vv6foNo/8xrOiG0WmMgt5qf/rpJzap0fL4448zF2HYTUgZdmRVtcXBgwfZQ00gNxNZV1u2bDmic3j11VdZLCTsAu0ONKmE3/wpHtd2kukJ5L4kS44mT3I5kUsvDFLuYcuW9nnnnXd2ug96AaDJnq4JXSt64ydFFraAu7pWpABpUqQ3fppMP/roI/QVSPmSxUqgfZM1RFZveHE4HF3GSklJtm1qQmUDZMH0N8hSI2uUrDmykineGpaHXI5t5SGrmsaSxuqqq65iL0n0gkPfUQZvW3l6uj/IsqZ7NwxSomH0Zv+D5foNoxXDim4Yh2DNmjXMJURuO3JL0UIWEU2cYZccuWQoTkMuG3qoKaOSlCFZPgRy51x//fVs8qG3cnIdut1u5lYkBdYdyDVJriOKj1AspS1o0qDElPD+qPSA3qbJyiNQvIRcWJR9SPVcZJWRVUpJBZ1h4sSJ7HtKAiEXKL2hh+NmP/zwA3Nf0gRLliK5qjqCJjVSxpQcEb5W9DJASjF8rU488URmldCETQqQXIokHylVUnY08VICzN///nd2Ta1WK3PXkpK9+uqrez1uJC8pUHLL0TWn+FH4mlDMkOKMtA1dN/q+7WTeFhR7o6QaOleyBGnsw9bq0YCSVOjYdA7h8+gu65HOmxQ2KTvKRg2DFM3zzz/PXkzo3iOlTVY/xVLpb7qeFMsj0PU83OxTcsc+9NBDLGZL40tKLYze7P9YXb9hHAWoH90whtEWS5YsCd5+++2HfL5q1apgXFxc0OfzsfVXXnklmJeXF1Sr1cGsrKzgQw89FOR5vmX7QCAQfOKJJ9g2crk8mJiYGFyxYkUwPz+/2+Onp6cHRSJRUKlUtizXXHMN+45+O3bs2KBCoQgaDIbgwoULg5s3b273+48++oidD53X/Pnzezzegw8+SK/fwffff7/ls507dwanTp0aVKlUQb1eHzz55JODVVVVh/yWZJ40adIhn9O2JMPu3bvZ+vr164OzZ88O6nS6YEJCQvCKK64INjY2tvvNV199FZwzZw6TNzo6mp37559/3u25v/baa0GhUMh+Q9ckNTU1ePHFFwf37t3bbrvff/89OG/ePCYL7Xvp0qXBsrIy9h0d56WXXmrZ36xZs4I33HBDUKPRBHNycoJff/11y37abhvenuQiHDx4kF3H8P3REX/729/Y920X+qw7LFq0KJiWlsbupY7XasqUKUGtVhuMj48PnnXWWUGr1cq+u+eee5icUVFRwdtuu43J3Va+8Pl2BbvdHrzwwgvZvkeOHBn8+9//HszMzGz5vqf9d3f9hjEwEND/jkZRDmMYw4gcUHo8EQGQNTkYQNmV5Pb7xz/+MWDn8NxzzzFPBLnyh9r1G0YIw67LYQxjGIMSFNdcvXo1rrjiin49LiXpkHuc3MfkwqbEl9NPP71fz2EYfYthRTeMAQHVxYVpq9ouFJ8bRntce+21nV4r+jxSQcXzlOTx5z//mdGa9Sco/kvlC1QQTmQHVJ5C8e
ZhDF0Muy6HMYxhDGMYEY1hi24YwxjGMIYR0RhWdMMYxjCGMYyIxrCiG8YwhjGMYQxpELsRdUqJKEW3bds2RBIiTR4
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 504x288 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
2025-10-21 09:20:43 +08:00
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"# Per-disease AUC: n_embd=256 model (y) against Delphi (x), 1-year gap.\n",
"fig, ax = plt.subplots(figsize=(7, 4))\n",
"# Points use Delphi's per-disease 'Colour' column.\n",
"ax.scatter(df_whole['auc_delphi'], df_whole['auc_256'],\n",
"           c=df_whole['Colour'], s=24, edgecolor='white', linewidth=0.65)\n",
"# Identity line: points above it have a higher AUC for the 256 model than for Delphi.\n",
"ax.axline((0.3, 0.3), slope=1, color='k', linestyle=(0, (5, 5)), linewidth=0.7)\n",
"# Chance-level AUC (0.5) on both axes.\n",
"for draw_reference in (ax.axvline, ax.axhline):\n",
"    draw_reference(0.5, color='gray', linestyle='--', linewidth=0.75)\n",
"ax.set_xlim(0.3, 1.05)\n",
"ax.set_ylim(0.3, 1.05)\n",
"ax.set_title('AUC_256 vs AUC_Delphi 1 year gap')\n",
"ax.set_ylabel('AUC_256')\n",
"ax.set_xlabel('AUC_Delphi')\n",
"fig.savefig('model_comparison_auc_256_vs_delphi_1year.png', dpi=600)"
]
|
|
|
|
|
},
|
|
|
|
|
{
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
"execution_count": 57,
|
|
|
|
|
"id": "61df30ab",
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"outputs": [
|
|
|
|
|
{
|
|
|
|
|
"data": {
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAboAAAEcCAYAAACxsnF2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAALEwAACxMBAJqcGAAAxBlJREFUeJzsXQd8FGX6frb3lk3vhST03pEuooCooNh7OT1PvfPO/9nuznannnrWO3tDVGxYUMQGqPQeSiAhIb3vZrO9zu7/936bbAohCS3ZrHl+v4HM7uzMvPPNfO+87Xl5gUAggAEMYAADGMAAIhT8vj6BAQxgAAMYwADOJAYU3QAGMIABDCCiMaDoBjCAAQxgABGNAUU3gAEMYAADiGgMKLoBDGAAAxhARGNA0Q1gAAMYwAAiGgOKbgAD+I1j1qxZeOONN0552/LyciiVSnAcd5rPcAADODUMKLoBdDux6XQ6uN3ubie8DRs2IDk5ObROJZovvPAChg8fDoVCwb675JJLsH///i6PuX79esyePRsajQbp6entvquvr8fll1+OxMRE9v20adOwbdu2dtt88MEHSEtLY8e88MIL0djYiN7AQw89BB6Pd8z50OdXXXXVMdvTtkVFRaH17777DjNmzIBKpUJMTAxmzpyJr776qstjvvPOOxAIBEzB0JKRkYHrr78ehYWF6G2kpqbCZrOx8+kML730EsaPHw+JRILrrruu189vAL9dDCi6ARwXpaWl+PXXX9mE3N2E2xnuuusuPP/880zZkbKhyZcUzzfffNPl70hB3XDDDXjqqaeO+Y4m0gkTJmDXrl1sn9deey0WLlzIPiccPHgQv/vd7/Dee++hrq4Ocrkcv//973GmQUp9+fLliIqKYv+fKD799FP2EnDNNdegsrKSnfsjjzyC1atXd/vbKVOmMPnNZjN+/PFHyGQyjBs3DgcOHEA4gV5OHnzwQTa24QKfz9fXpzCA3gAxowxgAJ3h4YcfDkydOjXwpz/9KbBw4cJ2382cOTPw+uuvt/ts/fr1gaSkJPZ3YWFhgM/nB7Zt23bSx//hhx8CaWlp3W6nUqkCO3fuZH/fd999gcsvvzz0XVFRUUAkEgUsFssxv7v11lsDf/7zn9t9tnjx4sAzzzzD/n7iiScCiYmJAaVSGcjJyQn8+OOPxz2Hn3/+OSCVSgMrVqwIREVFBdxud+i7f/zjH4Err7zymN/Q43fkyJGA3+8PpKSkBP79738HThRvv/12YNq0acd8TuO1dOnS0PqWLVsCU6ZMCWg0msDIkSPZWHU2lrQ/GvPbb789oFarA7m5ue3kpm0ffPBBtg1dl3nz5gUaGhrYdyUlJUwmr9fb5Tk/8MADgWuvvfa439O10+l0gX379oU+q6urC8hkskB9fT1bX716dWDUqFFMHpIrLy8vtO3jjz8eyMzMZOc3ZMiQwKpVq9pdLzr3P/7xj2yc6Fw6wuFwBK655pqAVqsNDB48OPDkk0+G7uue7v94128AfYMBi24AxwVZJldeeSVbyK1GVkZP8dNPPzFX5cSJE8/oOe7duxcejweDBg0KWXSjRo0KfZ+VlQWxWNypK49coB999BGzxggmkwnff/89LrvsMhQUFDBX244dO2C1Wpn8Hd2obfHuu+/i/PPPx7Jly9h6TyyxFtCxKioqcPHFF+N0YcmSJcwaJ1RVVTGrl6wpsoKffvppLF26FA0NDZ3+llyvdN0MBgMefvhhtq+27l9yDb/99tvMjUzXnvZ3OkHjRWOwYsWK0Gcffvgh5s6dy1y6e/bsYVbhq6++CqPRyCz4xYsXh9zrdO4kO1m4//jHP5jbuKampp18mZmZ7H5+4IEHjjk+yUzejKNHj+KHH35odx493X9X128AvY8BRTeATrFx40aUlZWxiZvcYPTg0gTXU9AElJCQcEbP0WKx4Oqrr2aTDcXrCOTCa/m7BbROyqojpk+fztyyLQqB3IfkBiQXG8WZaO
LMz8+H1+tlSo6uQWdwOBz45JNPcMUVV0AkEjGFdSLuS7pWhNN5vUiGlsmVJuoFCxawhc/nY968eSxWtmbNmk5/Gxsbiz/+8Y9MlksvvRS5ubnt3M0UA8zJyWEuUro/6GXjdINc0qTcWl5CyBVNY0147bXXmHKbNGkSGyfaluJ+W7duZd+TC5jkJ1np/LOzs7F9+/Z21+aOO+6AUChkMnTExx9/jPvvv5/Fpull7c4772z3fXf77+76DaD3MaDoBnBcC+Wcc85BdHQ0W6dJnD5rAU0SpADagtbp4Sbo9fp2b7mnG06nk1lQkydPxn333Rf6nBIySAG2Ba1TgkdHkJIjy4EmVAIpcrJeCWQhPvfccyyRhCYu2q66urrTc/n888/Z9SBFQqB9fPvttyGL6XjXikDXi64V4XReL7LiKF5IoBcWUsRarTa00IvM8Y6XlJTErk0LKLGnrezx8fGhvykG2hIfPZ0gJUb7pgSnw4cPs6Qdstpa5HnmmWfayUMWccs50kvG6NGjQ99RrJKsqxakpKR0eWzaT9ttOm7f3f67u34D6H0MKLoBdKpE6K32559/ZpMaLc8++yzy8vLY0pJhR+6dtigpKWEPNYHcTJRUsXPnztN+fmRpUVILvW2T+6othg0bFjpHArmfaHuyQDoDuS/JkqPJk1xO5NJrASn3FsuWJq6//vWvne6DXgBosqdrQteK3vhJkbVYwMe7VqQAaVKkN36aTD/77DOcLpDyJYuVQPsma6ipqSm02O123HvvvcdVkm2bmlDZAFkwvQ2y1MgaJWuOrGSpVBqSh1yObeUhq5rGksbq5ptvZm5nspTpO8r6bStPWyXUGciypnu3BaREW9CT/YfL9RtAKwYU3QCOwRdffMFcQuS2I7cULYcOHWITZ4tLjlwyFKchlw091BQDI2VIlg+B3DmU7UiTD72VUyzH5XJh5cqVeOKJJ7o8vt/vZ9uSsqB909/0ewJ9RpMeuZxIwZD7qC3ImqL4GLkjaTL/+9//zmIknVl0hDFjxjCr9aabbsL8+fPZG3pL3GzdunVMSdIES8freKyWSY3ikV9//XXoWpGiJaXYcq3OPfdcZpXQhE3nTy5Fco2RUiVlRxPvf/7zHzz66KPsmpIFSteAlOwtt9zS43Gj+jVSoOSWo2tOLl0CxZDomlCckbah60nft53M24Jib5QpS+dKliCNfYu1eqoZjnRsOoeW8+gq65HOmxQ2KTvKRm0BKZpXXnmFvZjQ/UHjTK5Bck/T33Q9KZZHoOt5otmn5I59/PHHWcyWxpeUWgt6sv8zdf0GcArooySYAYQx5s+fH7j77ruP+fyjjz4KxMXFhbLq3nzzzcDQoUNZ1mNWVhbLRuM4LrQ9ZRM+99xzbBvKmKMMxmXLlgUOHDjQ5fEpI5BuzbYLZfsRNmzYwNZpfwqFIrT88ssvod+///77LItRLpezLEqj0djl8R555BG2z48//jj0GWXxTZgwgWXWUQYgZTFWVVUd81uSeezYscd8TtsKhcLA/v372fqmTZtYdiRl8iUkJARuvPHGQGNjY7vffPvtt4GzzjqLyRMdHc1k/vrrr7s8d8ryo+xW+g3Jm5qayjIG8/Pz2223devWwIwZM5gstO8FCxYEysrKus26zM7ODnz33XfHzbZtm/XZXdYlZZ92HFf6rCvMnTuXZd7SvdTxWo0fP55lXcbHxwcuvvjiUGbt/fffz+TU6/UsY5jkbitfZ1mqbWGz2QJXXXUV2zdlXT766KMsy7IF3e2/q+s3gL4Bj/45FUU5gAEMIHJABehEBEDWZDiAsivJ7ffYY4/12Tm8/PLLzBNBrvz+dv0GEMSA63IAAxhAWILimqtWrcKNN97Yq8elJJ1NmzYx9zG5sCnx5aKLLurVcxjA6cWAohtAn4CSRlpoq9ou77//fl+fWtjh1ltv7fRa0eeRir/97W8syeOee+5htGa9CYoHU/kCxXXnzJmDCy64oFfYdQZw5jDguhzAAA
YwgAFENAYsugEMYAADGEBEY0DRDWAAAxjAACIaA4puAAMYwAAG0K9B7EbUKSWiFN3u3bsRSYg0eTqylfR3RNr4DMg
|
|
|
|
|
"text/plain": [
|
|
|
|
|
"<Figure size 504x288 with 1 Axes>"
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
"metadata": {},
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"source": [
"# Per-disease AUC: n_embd=120 model (y) against Delphi (x), 1-year gap.\n",
"fig, ax = plt.subplots(figsize=(7, 4))\n",
"# Points use Delphi's per-disease 'Colour' column.\n",
"ax.scatter(df_whole['auc_delphi'], df_whole['auc_120'],\n",
"           c=df_whole['Colour'], s=24, edgecolor='white', linewidth=0.65)\n",
"# Identity line: points above it have a higher AUC for the 120 model than for Delphi.\n",
"ax.axline((0.3, 0.3), slope=1, color='k', linestyle=(0, (5, 5)), linewidth=0.7)\n",
"# Chance-level AUC (0.5) on both axes.\n",
"for draw_reference in (ax.axvline, ax.axhline):\n",
"    draw_reference(0.5, color='gray', linestyle='--', linewidth=0.75)\n",
"ax.set_xlim(0.3, 1.05)\n",
"ax.set_ylim(0.3, 1.05)\n",
"ax.set_title('AUC_120 vs AUC_Delphi 1 year gap')\n",
"ax.set_ylabel('AUC_120')\n",
"ax.set_xlabel('AUC_Delphi')\n",
"fig.savefig('fig_auc_120_vs_delphi_1year.png', dpi=600)"
]
|
2025-10-20 10:14:50 +08:00
|
|
|
}
|
|
|
|
|
],
|
|
|
|
|
"metadata": {
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
"display_name": "base",
|
|
|
|
|
"language": "python",
|
|
|
|
|
"name": "python3"
|
|
|
|
|
},
|
|
|
|
|
"language_info": {
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
"version": 3
|
|
|
|
|
},
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
"name": "python",
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
"version": "3.12.9"
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
"nbformat_minor": 5
|
|
|
|
|
}
|