diff --git a/evaluate_auc.py b/evaluate_auc.py index b3a6c87..932eef3 100644 --- a/evaluate_auc.py +++ b/evaluate_auc.py @@ -1,8 +1,4 @@ -import scipy.stats -import scipy -import warnings import torch -from models import TimeAwareGPT2 from tqdm import tqdm import pandas as pd import numpy as np @@ -450,13 +446,15 @@ def main(): # Create output folder if it doesn't exist. Path(output_path).mkdir(exist_ok=True, parents=True) - device = "cuda" + device = "cuda" if torch.cuda.is_available() else "cpu" seed = 1337 # Load model checkpoint and initialize model. - model = load_model(f'config_{model_name}.json', - f'best_model_{model_name}.pt', - 1270) + model = load_model( + config_path=f'config_{model_name}.json', + model_path=f'best_model_{model_name}.pt', + device=device, + ) model.eval() model = model.to(device)