update evaluate

This commit is contained in:
2025-10-20 13:47:50 +08:00
parent 1c9e2a2fb3
commit 8f44018bae
3 changed files with 182 additions and 76 deletions

View File

@@ -2,12 +2,12 @@ import scipy.stats
import scipy
import warnings
import torch
from model import DelphiConfig, Delphi
from models import TimeAwareGPT2
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from utils import get_batch, get_p2i
from utils import load_model, get_batch, PatientEventDataset
from pathlib import Path
@@ -350,7 +350,7 @@ def evaluate_auc_pipeline(
total=d100k[0].shape[0] // batch_size + 1,
):
dd = [x.to(device) for x in dd]
outputs = model(*dd)[0].cpu().detach().numpy()
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
# Keep only the columns corresponding to the current disease chunk
p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates
p100k = np.vstack(p100k)
@@ -422,13 +422,6 @@ def evaluate_auc_pipeline(
def main():
parser = argparse.ArgumentParser(description="Evaluate AUC")
parser.add_argument("--input_path", type=str, help="Path to the dataset")
parser.add_argument("--output_path", type=str, help="Path to the output")
parser.add_argument("--model_ckpt_path", type=str, help="Path to the model weights")
parser.add_argument("--no_event_token_rate", type=int, help="No event token rate")
parser.add_argument(
"--health_token_replacement_prob", default=0.0, type=float, help="Health token replacement probability"
)
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
# Optional filtering/chunking parameters:
@@ -436,10 +429,7 @@ def main():
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
args = parser.parse_args()
input_path = args.input_path
output_path = args.output_path
no_event_token_rate = args.no_event_token_rate
health_token_replacement_prob = args.health_token_replacement_prob
output_path = './'
dataset_subset_size = args.dataset_subset_size
# Create output folder if it doesn't exist.
@@ -449,35 +439,26 @@ def main():
seed = 1337
# Load model checkpoint and initialize model.
ckpt_path = args.model_ckpt_path
checkpoint = torch.load(ckpt_path, map_location=device)
conf = DelphiConfig(**checkpoint["model_args"])
model = Delphi(conf)
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',
'best_model_n_embd_256_n_layer_16_n_head_16.pt',
1270)
model.eval()
model = model.to(device)
# Load training and validation data.
val = np.fromfile(f"{input_path}/val.bin", dtype=np.uint32).reshape(-1, 3).astype(np.int64)
val_p2i = get_p2i(val)
val_data_path = 'ukb_real_val.bin'
val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
block_length = 128
val_dataset = PatientEventDataset(val_data_arr, block_length)
if dataset_subset_size == -1:
dataset_subset_size = len(val_p2i)
dataset_subset_size = len(val_dataset)
# Get a subset batch for evaluation.
d100k = get_batch(
range(dataset_subset_size),
val,
val_p2i,
select="left",
block_size=80,
device=device,
padding="random",
no_event_token_rate=no_event_token_rate,
health_token_replacement_prob=health_token_replacement_prob,
)
d100k = get_batch(val_dataset, slice(dataset_subset_size))
# Load labels (external) to be passed in.
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")