From 8f44018baeb68b23c6c3672c647f31fce4c27baa Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Mon, 20 Oct 2025 13:47:50 +0800 Subject: [PATCH] update evaluate --- evaluate_auc.py | 49 ++++--------- evaluate_models.ipynb | 162 ++++++++++++++++++++++++++++++++---------- utils.py | 47 ++++++++++-- 3 files changed, 182 insertions(+), 76 deletions(-) diff --git a/evaluate_auc.py b/evaluate_auc.py index 635c3d7..7329ec5 100644 --- a/evaluate_auc.py +++ b/evaluate_auc.py @@ -2,12 +2,12 @@ import scipy.stats import scipy import warnings import torch -from model import DelphiConfig, Delphi +from models import TimeAwareGPT2 from tqdm import tqdm import pandas as pd import numpy as np import argparse -from utils import get_batch, get_p2i +from utils import load_model, get_batch, PatientEventDataset from pathlib import Path @@ -350,7 +350,7 @@ def evaluate_auc_pipeline( total=d100k[0].shape[0] // batch_size + 1, ): dd = [x.to(device) for x in dd] - outputs = model(*dd)[0].cpu().detach().numpy() + outputs = model(dd[0], dd[1]).cpu().detach().numpy() # Keep only the columns corresponding to the current disease chunk p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates p100k = np.vstack(p100k) @@ -422,13 +422,6 @@ def evaluate_auc_pipeline( def main(): parser = argparse.ArgumentParser(description="Evaluate AUC") - parser.add_argument("--input_path", type=str, help="Path to the dataset") - parser.add_argument("--output_path", type=str, help="Path to the output") - parser.add_argument("--model_ckpt_path", type=str, help="Path to the model weights") - parser.add_argument("--no_event_token_rate", type=int, help="No event token rate") - parser.add_argument( - "--health_token_replacement_prob", default=0.0, type=float, help="Health token replacement probability" - ) parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation") parser.add_argument("--n_bootstrap", type=int, default=1, 
help="Number of bootstrap samples") # Optional filtering/chunking parameters: @@ -436,10 +429,7 @@ def main(): parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases") args = parser.parse_args() - input_path = args.input_path - output_path = args.output_path - no_event_token_rate = args.no_event_token_rate - health_token_replacement_prob = args.health_token_replacement_prob + output_path = './' dataset_subset_size = args.dataset_subset_size # Create output folder if it doesn't exist. @@ -449,35 +439,26 @@ def main(): seed = 1337 # Load model checkpoint and initialize model. - ckpt_path = args.model_ckpt_path - checkpoint = torch.load(ckpt_path, map_location=device) - conf = DelphiConfig(**checkpoint["model_args"]) - model = Delphi(conf) - state_dict = checkpoint["model"] - model.load_state_dict(state_dict) + model = load_model('config_n_embd_256_n_layer_16_n_head_16.json', + 'best_model_n_embd_256_n_layer_16_n_head_16.pt', + 1270) model.eval() model = model.to(device) # Load training and validation data. - val = np.fromfile(f"{input_path}/val.bin", dtype=np.uint32).reshape(-1, 3).astype(np.int64) - val_p2i = get_p2i(val) + + val_data_path = 'ukb_real_val.bin' + + val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) + block_length = 128 + val_dataset = PatientEventDataset(val_data_arr, block_length) if dataset_subset_size == -1: - dataset_subset_size = len(val_p2i) + dataset_subset_size = len(val_dataset) # Get a subset batch for evaluation. - d100k = get_batch( - range(dataset_subset_size), - val, - val_p2i, - select="left", - block_size=80, - device=device, - padding="random", - no_event_token_rate=no_event_token_rate, - health_token_replacement_prob=health_token_replacement_prob, - ) + d100k = get_batch(val_dataset, slice(dataset_subset_size)) # Load labels (external) to be passed in. 
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv") diff --git a/evaluate_models.ipynb b/evaluate_models.ipynb index 2a3845a..1b18ed9 100644 --- a/evaluate_models.ipynb +++ b/evaluate_models.ipynb @@ -7,7 +7,6 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", "import torch\n", "from models import TimeAwareGPT2\n", "from utils import load_model\n", @@ -166,8 +165,8 @@ "output_type": "stream", "text": [ "Using device: cpu\n", - "Model config: {'n_layer': 12, 'n_embd': 120, 'n_head': 12, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n", - "Model loaded from best_model_n_embd_120_n_layer_12_n_head_12.pt with 2.40M parameters.\n" + "Model config: {'n_layer': 16, 'n_embd': 256, 'n_head': 16, 'max_epoch': 200, 'batch_size': 128, 'lr_initial': 0.0006, 'lr_final': 6e-05, 'weight_decay': 0.2, 'warmup_epochs': 10, 'early_stopping_patience': 10, 'pdrop': 0.0, 'token_pdrop': 0.0, 'betas': [0.9, 0.99]}\n", + "Model loaded from best_model_n_embd_256_n_layer_16_n_head_16.pt with 13.29M parameters.\n" ] } ], @@ -178,8 +177,8 @@ "torch.cuda.manual_seed(seed)\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f'Using device: {device}')\n", - "model = load_model('config_n_embd_120_n_layer_12_n_head_12.json',\n", - " 'best_model_n_embd_120_n_layer_12_n_head_12.pt', \n", + "model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',\n", + " 'best_model_n_embd_256_n_layer_16_n_head_16.pt', \n", " 1270)\n", "model.eval()\n", "model = model.to(device)" @@ -220,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 5, "id": "d7f2e4a1", "metadata": {}, "outputs": [ @@ -250,36 +249,29 @@ "42.0: No event\n", "=====================\n", "Generated trajectory:\n", - "44.9: E03 Other hypothyroidism\n", - "45.2: M19 Other arthrosis\n", - 
"45.7: H60 Otitis externa\n", - "47.1: H10 Conjunctivitis\n", - "47.6: H66 Suppurative and unspecified otitis media\n", - "48.4: M25 Other joint disorders, not elsewhere classified\n", - "49.8: M79 Other soft tissue disorders, not elsewhere classified\n", - "50.9: M65 Synovitis and tenosynovitis\n", - "51.2: D50 Iron deficiency anaemia\n", - "51.9: K57 Diverticular disease of intestine\n", - "52.6: K63 Other diseases of intestine\n", - "52.8: K66 Other disorders of peritoneum\n", - "52.9: A09 Diarrhoea and gastro-enteritis of presumed infectious origin\n", - "53.8: K20 Oesophagitis\n", - "53.8: K29 Gastritis and duodenitis\n", - "53.9: L30 Other dermatitis\n", - "54.5: M54 Dorsalgia\n", - "54.8: M15 Polyarthrosis\n", - "56.1: N30 Cystitis\n", - "56.3: F07 Personality and behavioural disorders due to brain disease, damage and dysfunction\n", - "56.4: K80 Cholelithiasis\n", - "56.4: M13 Other arthritis\n", - "56.6: B37 Candidiasis\n", - "58.0: E11 Non-insulin-dependent diabetes mellitus\n", - "58.0: K59 Other functional intestinal disorders\n", - "60.5: L03 Cellulitis\n", - "60.8: K52 Other non-infective gastro-enteritis and colitis\n", - "61.4: K31 Other diseases of stomach and duodenum\n", - "61.8: K74 Fibrosis and cirrhosis of liver\n", - "61.8: Death\n" + "43.9: B95 Streptococcus and staphylococcus as the cause of diseases classified to other chapters\n", + "44.6: J30 Vasomotor and allergic rhinitis\n", + "44.8: L03 Cellulitis\n", + "46.6: L30 Other dermatitis\n", + "46.8: K56 Paralytic ileus and intestinal obstruction without hernia\n", + "47.8: K76 Other diseases of liver\n", + "48.8: N20 Calculus of kidney and ureter\n", + "48.8: L24 Irritant contact dermatitis\n", + "49.5: I10 Essential primary hypertension\n", + "49.5: K59 Other functional intestinal disorders\n", + "50.6: B96 Other bacterial agents as the cause of diseases classified to other chapters\n", + "51.4: E14 Unspecified diabetes mellitus\n", + "51.6: E55 Vitamin d deficiency\n", + "51.7: A41 Other 
septicaemia\n", + "52.0: F41 Other anxiety disorders\n", + "52.1: F32 Depressive episode\n", + "52.3: M19 Other arthrosis\n", + "53.5: G96 Other disorders of central nervous system\n", + "53.7: E87 Other disorders of fluid, electrolyte and acid-base balance\n", + "53.8: J18 Pneumonia, organism unspecified\n", + "53.8: E66 Obesity\n", + "53.8: E11 Non-insulin-dependent diabetes mellitus\n", + "54.4: Death\n" ] } ], @@ -310,10 +302,106 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "0b937d75", "metadata": {}, "outputs": [], + "source": [ + "from utils import PatientEventDataset, get_batch\n", + "train_data_path = 'ukb_real_train.bin'\n", + "val_data_path = 'ukb_real_val.bin'\n", + "train_data_arr = np.memmap(train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)\n", + "val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)\n", + "block_length = 128\n", + "train_dataset = PatientEventDataset(train_data_arr, block_length)\n", + "val_dataset = PatientEventDataset(val_data_arr, block_length)\n", + "dataset_subset_size = 2048" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6a3bb2dc", + "metadata": {}, + "outputs": [], + "source": [ + "input_events, input_times, target_events, target_times = get_batch(val_dataset, slice(256))\n", + "with torch.no_grad():\n", + " p = model(input_events.to(device), input_times.to(device))\n", + " p = p.cpu().detach().numpy().squeeze()\n", + "t = (target_times-input_times).cpu().numpy().squeeze()\n", + "target_events = target_events.cpu().numpy().squeeze()\n", + "ignored_token_ids = [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]\n", + "mask = ~np.isin(target_events, ignored_token_ids)\n", + "p = p[mask]\n", + "t = t[mask]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1c32bd6e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from scipy.special import logsumexp\n", + "\n", + "# Calculate expected waiting times from model predictions using competing exponentials theory\n", + "# In Delphi's framework, each possible event has an exponential distribution with rate λᵢ = exp(logits[i])\n", + "# The expected time until any event occurs is 1/sum(λᵢ) = 1/exp(logsumexp(logits))\n", + "# logsumexp provides numerical stability vs. calculating exp(logits) directly\n", + "\n", + "# Let's see how the predicted waiting times compare to the observed waiting times\n", + "\n", + "plt.figure(figsize=(4, 4))\n", + "# Calculate expected time to next token (inverse of hazard rate)\n", + "expected_t = 1/np.exp(logsumexp(p, axis=-1))\n", + "\n", + "# Define bin width for logarithmic binning\n", + "delta_log_t = 0.1\n", + "log_range = np.arange(1.75, 4, delta_log_t)\n", + "\n", + "# Calculate average observed time for each logarithmic bin\n", + "observed_t = []\n", + "for i in log_range:\n", + " # Create mask for current bin and valid times\n", + " bin_mask = (expected_t > 10**i) & (expected_t <= 10**(i+delta_log_t)) & (t > 0)\n", + " # Calculate mean for this bin\n", + " bin_mean = t[bin_mask].mean() if bin_mask.sum() > 0 else np.nan\n", + " observed_t.append(bin_mean)\n", + "plt.axes().set_aspect('equal')\n", + "plt.scatter(expected_t, t+0.5, marker=\".\", c='lightgrey', rasterized=True)\n", + "plt.xlabel('Expected days to next token')\n", + "plt.ylabel('Observed days to next token')\n", + "plt.plot(10**(np.arange(1.75,4,delta_log_t)+delta_log_t/2.),observed_t, label='average')\n", + "plt.yscale('log')\n", + "plt.xscale('log')\n", + "plt.legend()\n", + "plt.xlim(1,2e3)\n", + "plt.ylim(1,2e3)\n", + "plt.plot([0,1],[0,1], transform = plt.gca().transAxes, c='k' , ls=(0, (5, 5)), linewidth=0.7)\n", + "\n", + "plt.gca().tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=1, grid_linewidth=0.45, grid_linestyle=':')\n", + 
"plt.gca().tick_params(length=1.15, width=0.3, labelsize=8, grid_alpha=0.0, grid_linewidth=0.35, which='minor')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e5e3c8f", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/utils.py b/utils.py index bab1538..ef25e81 100644 --- a/utils.py +++ b/utils.py @@ -42,17 +42,22 @@ class PatientEventDataset(torch.utils.data.Dataset): """ return len(self.patient_ids) - def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, idx): """ - Retrieves, processes, and returns a single patient's event sequence. + Retrieves, processes, and returns a single patient's event sequence, + or a list of sequences if a slice is provided. Args: - idx (int): The index of the patient to retrieve. + idx (int or slice): The index or slice of the patient(s) to retrieve. Returns: - A tuple of two torch.long tensors: (event_sequence, time_sequence), - both of shape (block_length,). + If idx is an int, a tuple of two torch.long tensors: + (event_sequence, time_sequence), both of shape (block_length,). + If idx is a slice, a list of such tuples. """ + if isinstance(idx, slice): + return [self[i] for i in range(*idx.indices(len(self)))] + # 1. Retrieve and Sort patient_id = self.patient_ids[idx] records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) @@ -150,3 +155,35 @@ def load_model(config_path, model_path, vocab_size, device='cpu'): print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.") return model + + +def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Retrieves a batch of data from a PatientEventDataset and prepares it for model training. + + Args: + dataset (PatientEventDataset): The dataset to retrieve data from. + batch_slice (slice): The slice defining the batch of patients to retrieve. 
+ + Note: target events are returned unmasked; this function takes no ignore_tokens argument and never replaces token IDs with -100, so callers must filter ignored tokens themselves. + + Returns: + A tuple containing four tensors: + - input_events: (batch_size, sequence_length - 1) + - input_times: (batch_size, sequence_length - 1) + - target_events: (batch_size, sequence_length - 1) + - target_times: (batch_size, sequence_length - 1) + """ + batch_data = dataset[batch_slice] + + input_events = [item[0][:-1] for item in batch_data] + input_tims = [item[1][:-1] for item in batch_data] + target_events = [item[0][1:] for item in batch_data] + target_times = [item[1][1:] for item in batch_data] + + input_events = torch.stack(input_events) + input_tims = torch.stack(input_tims) + target_events = torch.stack(target_events) + target_times = torch.stack(target_times) + + return input_events, input_tims, target_events, target_times \ No newline at end of file