update evaluate
This commit is contained in:
@@ -2,12 +2,12 @@ import scipy.stats
|
||||
import scipy
|
||||
import warnings
|
||||
import torch
|
||||
from model import DelphiConfig, Delphi
|
||||
from models import TimeAwareGPT2
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import argparse
|
||||
from utils import get_batch, get_p2i
|
||||
from utils import load_model, get_batch, PatientEventDataset
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -350,7 +350,7 @@ def evaluate_auc_pipeline(
|
||||
total=d100k[0].shape[0] // batch_size + 1,
|
||||
):
|
||||
dd = [x.to(device) for x in dd]
|
||||
outputs = model(*dd)[0].cpu().detach().numpy()
|
||||
outputs = model(dd[0], dd[1]).cpu().detach().numpy()
|
||||
# Keep only the columns corresponding to the current disease chunk
|
||||
p100k.append(outputs[:, :, diseases_chunk].astype("float16")) # enough to store logits, but not rates
|
||||
p100k = np.vstack(p100k)
|
||||
@@ -422,13 +422,6 @@ def evaluate_auc_pipeline(
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Evaluate AUC")
|
||||
parser.add_argument("--input_path", type=str, help="Path to the dataset")
|
||||
parser.add_argument("--output_path", type=str, help="Path to the output")
|
||||
parser.add_argument("--model_ckpt_path", type=str, help="Path to the model weights")
|
||||
parser.add_argument("--no_event_token_rate", type=int, help="No event token rate")
|
||||
parser.add_argument(
|
||||
"--health_token_replacement_prob", default=0.0, type=float, help="Health token replacement probability"
|
||||
)
|
||||
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
|
||||
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
|
||||
# Optional filtering/chunking parameters:
|
||||
@@ -436,10 +429,7 @@ def main():
|
||||
parser.add_argument("--disease_chunk_size", type=int, default=200, help="Chunk size for processing diseases")
|
||||
args = parser.parse_args()
|
||||
|
||||
input_path = args.input_path
|
||||
output_path = args.output_path
|
||||
no_event_token_rate = args.no_event_token_rate
|
||||
health_token_replacement_prob = args.health_token_replacement_prob
|
||||
output_path = './'
|
||||
dataset_subset_size = args.dataset_subset_size
|
||||
|
||||
# Create output folder if it doesn't exist.
|
||||
@@ -449,35 +439,26 @@ def main():
|
||||
seed = 1337
|
||||
|
||||
# Load model checkpoint and initialize model.
|
||||
ckpt_path = args.model_ckpt_path
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
conf = DelphiConfig(**checkpoint["model_args"])
|
||||
model = Delphi(conf)
|
||||
state_dict = checkpoint["model"]
|
||||
model.load_state_dict(state_dict)
|
||||
model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',
|
||||
'best_model_n_embd_256_n_layer_16_n_head_16.pt',
|
||||
1270)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
# Load training and validation data.
|
||||
val = np.fromfile(f"{input_path}/val.bin", dtype=np.uint32).reshape(-1, 3).astype(np.int64)
|
||||
|
||||
val_p2i = get_p2i(val)
|
||||
|
||||
val_data_path = 'ukb_real_val.bin'
|
||||
|
||||
val_data_arr = np.memmap(val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
|
||||
block_length = 128
|
||||
val_dataset = PatientEventDataset(val_data_arr, block_length)
|
||||
|
||||
if dataset_subset_size == -1:
|
||||
dataset_subset_size = len(val_p2i)
|
||||
dataset_subset_size = len(val_dataset)
|
||||
|
||||
# Get a subset batch for evaluation.
|
||||
d100k = get_batch(
|
||||
range(dataset_subset_size),
|
||||
val,
|
||||
val_p2i,
|
||||
select="left",
|
||||
block_size=80,
|
||||
device=device,
|
||||
padding="random",
|
||||
no_event_token_rate=no_event_token_rate,
|
||||
health_token_replacement_prob=health_token_replacement_prob,
|
||||
)
|
||||
d100k = get_batch(val_dataset, slice(dataset_subset_size))
|
||||
|
||||
# Load labels (external) to be passed in.
|
||||
delphi_labels = pd.read_csv("delphi_labels_chapters_colours_icd.csv")
|
||||
|
Reference in New Issue
Block a user