update evaluate
This commit is contained in:
@@ -2,12 +2,12 @@ import scipy.stats
|
||||
import scipy
|
||||
import warnings
|
||||
import torch
|
||||
from model import DelphiConfig, Delphi
|
||||
from models import TimeAwareGPT2
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import argparse
|
||||
from utils import get_batch, get_p2i
|
||||
from utils import load_model, get_batch, PatientEventDataset
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -350,7 +350,7 @@ def evaluate_auc_pipeline(
|
||||
total=d100k[0].shape[0] // batch_size + 1,
|
||||
):
|
||||
dd = [x.to(device) for x in dd]
|
||||
outputs = model(*dd)[0].cpu().detach().numpy()
|
||||
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
|
||||
# Keep only the columns corresponding to the current disease chunk
|
||||
p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates
|
||||
p100k = np.vstack(p100k)
|
||||
@@ -422,13 +422,6 @@ def evaluate_auc_pipeline(
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Evaluate AUC")
|
||||
parser.add_argument("--input_path", type=str, help="Path to the dataset")
|
||||
parser.add_argument("--output_path", type=str, help="Path to the output")
|
||||
parser.add_argument("--model_ckpt_path", type=str, help="Path to the model weights")
|
||||
parser.add_argument("--no_event_token_rate", type=int, help="No event token rate")
|
||||
parser.add_argument(
|
||||
"--health_token_replacement_prob", default=0.0, type=float, help="Health token replacement probability"
|
||||
)
|
||||
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
|
||||
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
|
||||
# Optional filtering/chunking parameters:
|
||||
@@ -436,10 +429,7 @@ def main():
|
||||
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
|
||||
args = parser.parse_args()
|
||||
|
||||
input_path = args.input_path
|
||||
output_path = args.output_path
|
||||
no_event_token_rate = args.no_event_token_rate
|
||||
health_token_replacement_prob = args.health_token_replacement_prob
|
||||
output_path = './'
|
||||
dataset_subset_size = args.dataset_subset_size
|
||||
|
||||
# Create output folder if it doesn't exist.
|
||||
@@ -449,35 +439,26 @@ def main():
|
||||
seed = 1337
|
||||
|
||||
# Load model checkpoint and initialize model.
|
||||
ckpt_path = args.model_ckpt_path
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
conf = DelphiConfig(**checkpoint["model_args"])
|
||||
model = Delphi(conf)
|
||||
state_dict = checkpoint["model"]
|
||||
model.load_state_dict(state_dict)
|
||||
model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',
|
||||
'best_model_n_embd_256_n_layer_16_n_head_16.pt',
|
||||
1270)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
# Load training and validation data.
|
||||
val = np.fromfile(f"{input_path}/val.bin", dtype=np.uint32).reshape(-1, 3).astype(np.int64)
|
||||
|
||||
val_p2i = get_p2i(val)
|
||||
|
||||
val_data_path = 'ukb_real_val.bin'
|
||||
|
||||
val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
|
||||
block_length = 128
|
||||
val_dataset = PatientEventDataset(val_data_arr, block_length)
|
||||
|
||||
if dataset_subset_size == -1:
|
||||
dataset_subset_size = len(val_p2i)
|
||||
dataset_subset_size = len(val_dataset)
|
||||
|
||||
# Get a subset batch for evaluation.
|
||||
d100k = get_batch(
|
||||
range(dataset_subset_size),
|
||||
val,
|
||||
val_p2i,
|
||||
select="left",
|
||||
block_size=80,
|
||||
device=device,
|
||||
padding="random",
|
||||
no_event_token_rate=no_event_token_rate,
|
||||
health_token_replacement_prob=health_token_replacement_prob,
|
||||
)
|
||||
d100k = get_batch(val_dataset, slice(dataset_subset_size))
|
||||
|
||||
# Load labels (external) to be passed in.
|
||||
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")
|
||||
|
File diff suppressed because one or more lines are too long
47
utils.py
47
utils.py
@@ -42,17 +42,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
return len(self.patient_ids)
|
||||
|
||||
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Retrieves, processes, and returns a single patient's event sequence.
|
||||
Retrieves, processes, and returns a single patient's event sequence,
|
||||
or a list of sequences if a slice is provided.
|
||||
|
||||
Args:
|
||||
idx (int): The index of the patient to retrieve.
|
||||
idx (int or slice): The index or slice of the patient(s) to retrieve.
|
||||
|
||||
Returns:
|
||||
A tuple of two torch.long tensors: (event_sequence, time_sequence),
|
||||
both of shape (block_length,).
|
||||
If idx is an int, a tuple of two torch.long tensors:
|
||||
(event_sequence, time_sequence), both of shape (block_length,).
|
||||
If idx is a slice, a list of such tuples.
|
||||
"""
|
||||
if isinstance(idx, slice):
|
||||
return [self[i] for i in range(*idx.indices(len(self)))]
|
||||
|
||||
# 1. Retrieve and Sort
|
||||
patient_id = self.patient_ids[idx]
|
||||
records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
|
||||
@@ -150,3 +155,35 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
|
||||
|
||||
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
|
||||
return model
|
||||
|
||||
|
||||
def get_batch(
    dataset: "PatientEventDataset",
    batch_slice: slice,
    ignore_tokens: "list[int] | None" = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.

    Each patient's (event_sequence, time_sequence) pair is split into
    next-token inputs (every position but the last) and targets (every
    position but the first), then stacked along a new batch dimension.

    Args:
        dataset (PatientEventDataset): The dataset to retrieve data from. It must
            support slice indexing and yield (event_sequence, time_sequence) tuples.
        batch_slice (slice): The slice defining the batch of patients to retrieve.
        ignore_tokens (list, optional): A list of token IDs to be ignored in the target events.
            These tokens will be replaced with -100. Defaults to None, which
            leaves the target events untouched.

    Returns:
        A tuple containing four tensors:
            - input_events: (batch_size, sequence_length - 1)
            - input_times: (batch_size, sequence_length - 1)
            - target_events: (batch_size, sequence_length - 1)
            - target_times: (batch_size, sequence_length - 1)

    Raises:
        RuntimeError: Propagated from ``torch.stack`` when ``batch_slice``
            selects no patients (an empty list cannot be stacked).
    """
    batch_data = dataset[batch_slice]

    # Inputs drop the final position; targets are the same sequences shifted
    # left by one, so position t of the input predicts position t of the target.
    input_events = torch.stack([events[:-1] for events, _ in batch_data])
    input_times = torch.stack([times[:-1] for _, times in batch_data])
    target_events = torch.stack([events[1:] for events, _ in batch_data])
    target_times = torch.stack([times[1:] for _, times in batch_data])

    if ignore_tokens is not None and len(ignore_tokens) > 0:
        # Build the lookup tensor with the targets' dtype/device so torch.isin
        # can compare them directly; masked_fill returns a new tensor, leaving
        # the input tensors unaffected.
        ignored = torch.as_tensor(
            list(ignore_tokens),
            dtype=target_events.dtype,
            device=target_events.device,
        )
        target_events = target_events.masked_fill(
            torch.isin(target_events, ignored), -100
        )

    return input_events, input_times, target_events, target_times
|
Reference in New Issue
Block a user