evaluate_auc: use new utils.load_model (infer vocab, model variants) and dynamic device; remove unused imports
This commit is contained in:
@@ -1,8 +1,4 @@
|
||||
import scipy.stats
|
||||
import scipy
|
||||
import warnings
|
||||
import torch
|
||||
from models import TimeAwareGPT2
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -450,13 +446,15 @@ def main():
|
||||
# Create output folder if it doesn't exist.
|
||||
Path(output_path).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
device = "cuda"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
seed = 1337
|
||||
|
||||
# Load model checkpoint and initialize model.
|
||||
model = load_model(f'config_{model_name}.json',
|
||||
f'best_model_{model_name}.pt',
|
||||
1270)
|
||||
model = load_model(
|
||||
config_path=f'config_{model_name}.json',
|
||||
model_path=f'best_model_{model_name}.pt',
|
||||
device=device,
|
||||
)
|
||||
model.eval()
|
||||
model = model.to(device)
|
||||
|
||||
|
Reference in New Issue
Block a user