update evaluate

This commit is contained in:
2025-10-20 13:47:50 +08:00
parent 1c9e2a2fb3
commit 8f44018bae
3 changed files with 182 additions and 76 deletions

View File

@@ -2,12 +2,12 @@ import scipy.stats
import scipy import scipy
import warnings import warnings
import torch import torch
from model import DelphiConfig, Delphi from models import TimeAwareGPT2
from tqdm import tqdm from tqdm import tqdm
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import argparse import argparse
from utils import get_batch, get_p2i from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path from pathlib import Path
@@ -350,7 +350,7 @@ def evaluate_auc_pipeline(
total=d100k[0].shape[0] // batch_size + 1, total=d100k[0].shape[0] // batch_size + 1,
): ):
dd = [x.to(device) for x in dd] dd = [x.to(device) for x in dd]
outputs = model(*dd)[0].cpu().detach().numpy() outputs = model(dd[0], dd[1]).cpu().detach().numpy()
# Keep only the columns corresponding to the current disease chunk # Keep only the columns corresponding to the current disease chunk
p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates
p100k = np.vstack(p100k) p100k = np.vstack(p100k)
@@ -422,13 +422,6 @@ def evaluate_auc_pipeline(
def main(): def main():
parser = argparse.ArgumentParser(description="Evaluate AUC") parser = argparse.ArgumentParser(description="Evaluate AUC")
parser.add_argument("--input_path", type=str, help="Path to the dataset")
parser.add_argument("--output_path", type=str, help="Path to the output")
parser.add_argument("--model_ckpt_path", type=str, help="Path to the model weights")
parser.add_argument("--no_event_token_rate", type=int, help="No event token rate")
parser.add_argument(
"--health_token_replacement_prob", default=0.0, type=float, help="Health token replacement probability"
)
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation") parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples") parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
# Optional filtering/chunking parameters: # Optional filtering/chunking parameters:
@@ -436,10 +429,7 @@ def main():
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases") parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
args = parser.parse_args() args = parser.parse_args()
input_path = args.input_path output_path = './'
output_path = args.output_path
no_event_token_rate = args.no_event_token_rate
health_token_replacement_prob = args.health_token_replacement_prob
dataset_subset_size = args.dataset_subset_size dataset_subset_size = args.dataset_subset_size
# Create output folder if it doesn't exist. # Create output folder if it doesn't exist.
@@ -449,35 +439,26 @@ def main():
seed = 1337 seed = 1337
# Load model checkpoint and initialize model. # Load model checkpoint and initialize model.
ckpt_path = args.model_ckpt_path model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',
checkpoint = torch.load(ckpt_path, map_location=device) 'best_model_n_embd_256_n_layer_16_n_head_16.pt',
conf = DelphiConfig(**checkpoint["model_args"]) 1270)
model = Delphi(conf)
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.eval() model.eval()
model = model.to(device) model = model.to(device)
# Load training and validation data. # Load training and validation data.
val = np.fromfile(f"{input_path}/val.bin", dtype=np.uint32).reshape(-1, 3).astype(np.int64)
val_p2i = get_p2i(val)
val_data_path = 'ukb_real_val.bin'
val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
block_length = 128
val_dataset = PatientEventDataset(val_data_arr, block_length)
if dataset_subset_size == -1: if dataset_subset_size == -1:
dataset_subset_size = len(val_p2i) dataset_subset_size = len(val_dataset)
# Get a subset batch for evaluation. # Get a subset batch for evaluation.
d100k = get_batch( d100k = get_batch(val_dataset, slice(dataset_subset_size))
range(dataset_subset_size),
val,
val_p2i,
select="left",
block_size=80,
device=device,
padding="random",
no_event_token_rate=no_event_token_rate,
health_token_replacement_prob=health_token_replacement_prob,
)
# Load labels (external) to be passed in. # Load labels (external) to be passed in.
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv") delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")

File diff suppressed because one or more lines are too long

View File

@@ -42,17 +42,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
""" """
return len(self.patient_ids) return len(self.patient_ids)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: def __getitem__(self, idx):
""" """
Retrieves, processes, and returns a single patient's event sequence. Retrieves, processes, and returns a single patient's event sequence,
or a list of sequences if a slice is provided.
Args: Args:
idx (int): The index of the patient to retrieve. idx (int or slice): The index or slice of the patient(s) to retrieve.
Returns: Returns:
A tuple of two torch.long tensors: (event_sequence, time_sequence), If idx is an int, a tuple of two torch.long tensors:
both of shape (block_length,). (event_sequence, time_sequence), both of shape (block_length,).
If idx is a slice, a list of such tuples.
""" """
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
# 1. Retrieve and Sort # 1. Retrieve and Sort
patient_id = self.patient_ids[idx] patient_id = self.patient_ids[idx]
records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
@@ -150,3 +155,35 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.") print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
return model return model
def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
Args:
dataset (PatientEventDataset): The dataset to retrieve data from.
batch_slice (slice): The slice defining the batch of patients to retrieve.
ignore_tokens (list, optional): A list of token IDs to be ignored in the target events.
These tokens will be replaced with -100. Defaults to None.
Returns:
A tuple containing four tensors:
- input_events: (batch_size, sequence_length - 1)
- input_tims: (batch_size, sequence_length - 1)
- target_events: (batch_size, sequence_length - 1)
- target_times: (batch_size, sequence_length - 1)
"""
batch_data = dataset[batch_slice]
input_events = [item[0][:-1] for item in batch_data]
input_tims = [item[1][:-1] for item in batch_data]
target_events = [item[0][1:] for item in batch_data]
target_times = [item[1][1:] for item in batch_data]
input_events = torch.stack(input_events)
input_tims = torch.stack(input_tims)
target_events = torch.stack(target_events)
target_times = torch.stack(target_times)
return input_events, input_tims, target_events, target_times