update evaluate

2025-10-20 13:47:50 +08:00
parent 1c9e2a2fb3
commit 8f44018bae
3 changed files with 182 additions and 76 deletions


@@ -2,12 +2,12 @@ import scipy.stats
 import scipy
 import warnings
 import torch
-from model import DelphiConfig, Delphi
+from models import TimeAwareGPT2
 from tqdm import tqdm
 import pandas as pd
 import numpy as np
 import argparse
-from utils import get_batch, get_p2i
+from utils import load_model, get_batch, PatientEventDataset
 from pathlib import Path
@@ -350,7 +350,7 @@ def evaluate_auc_pipeline(
         total=d100k[0].shape[0] // batch_size + 1,
     ):
         dd = [x.to(device) for x in dd]
-        outputs = model(*dd)[0].cpu().detach().numpy()
+        outputs = model(dd[0], dd[1]).cpu().detach().numpy()
         # Keep only the columns corresponding to the current disease chunk
         p100k.append(outputs[:, :, diseases_chunk].astype("float16"))  # enough to store logits, but not rates
     p100k = np.vstack(p100k)
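
The new call site implies that TimeAwareGPT2's forward takes the event and time tensors as two positional arguments and returns logits directly, whereas the old Delphi forward returned a tuple that had to be unpacked with [0]. A minimal sketch of that assumed interface, using a hypothetical stand-in module (StubTimeAwareModel is invented here; the real class lives in models.py):

    import torch

    class StubTimeAwareModel(torch.nn.Module):
        """Stand-in with the assumed interface: forward(events, times) -> logits."""
        def __init__(self, vocab_size=1270, n_embd=256):
            super().__init__()
            self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
            self.head = torch.nn.Linear(n_embd, vocab_size)

        def forward(self, events, times):
            # A real time-aware model would also condition on `times`;
            # this stub ignores them and just maps tokens to logits.
            return self.head(self.tok_emb(events))

    model = StubTimeAwareModel()
    events = torch.randint(0, 1270, (4, 128))   # (batch, block_length) event ids
    times = torch.rand(4, 128)                  # per-event times
    with torch.no_grad():
        logits = model(events, times)           # returned directly, no [0] unpacking
    print(logits.shape)                         # torch.Size([4, 128, 1270])
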
@@ -422,13 +422,6 @@
 def main():
     parser = argparse.ArgumentParser(description="Evaluate AUC")
     parser.add_argument("--input_path", type=str, help="Path to the dataset")
-    parser.add_argument("--output_path", type=str, help="Path to the output")
-    parser.add_argument("--model_ckpt_path", type=str, help="Path to the model weights")
-    parser.add_argument("--no_event_token_rate", type=int, help="No event token rate")
-    parser.add_argument(
-        "--health_token_replacement_prob", default=0.0, type=float, help="Health token replacement probability"
-    )
     parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
     parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
     # Optional filtering/chunking parameters:
@@ -436,10 +429,7 @@ def main():
     parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
     args = parser.parse_args()
     input_path = args.input_path
-    output_path = args.output_path
-    no_event_token_rate = args.no_event_token_rate
-    health_token_replacement_prob = args.health_token_replacement_prob
+    output_path = './'
    dataset_subset_size = args.dataset_subset_size
     # Create output folder if it doesn't exist.
@@ -449,35 +439,26 @@
     seed = 1337
     # Load model checkpoint and initialize model.
-    ckpt_path = args.model_ckpt_path
-    checkpoint = torch.load(ckpt_path, map_location=device)
-    conf = DelphiConfig(**checkpoint["model_args"])
-    model = Delphi(conf)
-    state_dict = checkpoint["model"]
-    model.load_state_dict(state_dict)
+    model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',
+                       'best_model_n_embd_256_n_layer_16_n_head_16.pt',
+                       1270)
     model.eval()
     model = model.to(device)
     # Load training and validation data.
-    val = np.fromfile(f"{input_path}/val.bin", dtype=np.uint32).reshape(-1, 3).astype(np.int64)
-    val_p2i = get_p2i(val)
+    val_data_path = 'ukb_real_val.bin'
+    val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
+    block_length = 128
+    val_dataset = PatientEventDataset(val_data_arr, block_length)
     if dataset_subset_size == -1:
-        dataset_subset_size = len(val_p2i)
+        dataset_subset_size = len(val_dataset)
     # Get a subset batch for evaluation.
-    d100k = get_batch(
-        range(dataset_subset_size),
-        val,
-        val_p2i,
-        select="left",
-        block_size=80,
-        device=device,
-        padding="random",
-        no_event_token_rate=no_event_token_rate,
-        health_token_replacement_prob=health_token_replacement_prob,
-    )
+    d100k = get_batch(val_dataset, slice(dataset_subset_size))
     # Load labels (external) to be passed in.
     delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")
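
The validation data is now memory-mapped rather than loaded with np.fromfile, and reshaped to one uint32 triplet per row. A runnable miniature of that access pattern follows; the (patient_id, time, event_token) column order is an assumption for illustration, not something this diff confirms:

    import numpy as np

    # Hypothetical miniature of ukb_real_val.bin: flat uint32 values, three per
    # recorded event. The (patient_id, time, event_token) layout is assumed.
    flat = np.array([1, 30, 7,
                     1, 42, 9,
                     2, 55, 3], dtype=np.uint32)
    flat.tofile('toy_val.bin')

    # mode='r' maps the file read-only instead of loading it all into RAM.
    arr = np.memmap('toy_val.bin', dtype=np.uint32, mode='r').reshape(-1, 3)
    print(arr.shape)   # (3, 3): one row per recorded event
    print(arr[0])      # [ 1 30  7]
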

File diff suppressed because one or more lines are too long


@@ -42,17 +42,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
         """
         return len(self.patient_ids)
 
-    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+    def __getitem__(self, idx):
         """
-        Retrieves, processes, and returns a single patient's event sequence.
+        Retrieves, processes, and returns a single patient's event sequence,
+        or a list of sequences if a slice is provided.
 
         Args:
-            idx (int): The index of the patient to retrieve.
+            idx (int or slice): The index or slice of the patient(s) to retrieve.
 
         Returns:
-            A tuple of two torch.long tensors: (event_sequence, time_sequence),
-            both of shape (block_length,).
+            If idx is an int, a tuple of two torch.long tensors:
+            (event_sequence, time_sequence), both of shape (block_length,).
+            If idx is a slice, a list of such tuples.
         """
+        if isinstance(idx, slice):
+            return [self[i] for i in range(*idx.indices(len(self)))]
+
         # 1. Retrieve and Sort
         patient_id = self.patient_ids[idx]
         records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
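
With the new isinstance(idx, slice) branch, dataset[a:b] delegates to integer indexing and returns an ordinary Python list of (event_sequence, time_sequence) tuples; slice.indices(len(self)) clamps out-of-range bounds and resolves steps and negative indices. A self-contained toy class showing the same delegation pattern:

    class Toy:
        """Minimal stand-in reproducing the slice-delegation pattern."""
        def __len__(self):
            return 5

        def __getitem__(self, idx):
            if isinstance(idx, slice):
                return [self[i] for i in range(*idx.indices(len(self)))]
            return idx * 10  # placeholder for the per-patient tensors

    t = Toy()
    print(t[1:3])     # [10, 20]
    print(t[::2])     # [0, 20, 40] -- step handled by slice.indices
    print(t[3:100])   # [30, 40]    -- stop clamped to len(t)
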
@@ -150,3 +155,35 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
     print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
     return model
+
+
+def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
+
+    Args:
+        dataset (PatientEventDataset): The dataset to retrieve data from.
+        batch_slice (slice): The slice defining the batch of patients to retrieve.
+
+    Returns:
+        A tuple of four tensors, each of shape (batch_size, sequence_length - 1):
+        - input_events: event tokens at positions 0..n-2
+        - input_times: event times at positions 0..n-2
+        - target_events: event tokens at positions 1..n-1 (next-event labels)
+        - target_times: event times at positions 1..n-1 (next-time labels)
+    """
+    batch_data = dataset[batch_slice]
+
+    # Shift by one position: inputs drop the last element, targets drop the first.
+    input_events = [item[0][:-1] for item in batch_data]
+    input_times = [item[1][:-1] for item in batch_data]
+    target_events = [item[0][1:] for item in batch_data]
+    target_times = [item[1][1:] for item in batch_data]
+
+    input_events = torch.stack(input_events)
+    input_times = torch.stack(input_times)
+    target_events = torch.stack(target_events)
+    target_times = torch.stack(target_times)
+
+    return input_events, input_times, target_events, target_times
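
get_batch thus builds standard next-token-prediction pairs: each target sequence is its input shifted left by one position. A quick, self-contained check of the shift with toy tensors standing in for dataset items:

    import torch

    events = torch.tensor([5, 9, 2, 7])
    times = torch.tensor([0, 1, 3, 6])
    batch_data = [(events, times)]              # one fake dataset item

    input_events = torch.stack([item[0][:-1] for item in batch_data])
    target_events = torch.stack([item[0][1:] for item in batch_data])
    print(input_events)    # tensor([[5, 9, 2]])
    print(target_events)   # tensor([[9, 2, 7]]) -- each target is the next event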