evaluate_auc: use new utils.load_model (infer vocab, model variants) and dynamic device; remove unused imports
This commit is contained in:
@@ -1,8 +1,4 @@
|
|||||||
import scipy.stats
|
|
||||||
import scipy
|
|
||||||
import warnings
|
|
||||||
import torch
|
import torch
|
||||||
from models import TimeAwareGPT2
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -450,13 +446,15 @@ def main():
|
|||||||
# Create output folder if it doesn't exist.
|
# Create output folder if it doesn't exist.
|
||||||
Path(output_path).mkdir(exist_ok=True, parents=True)
|
Path(output_path).mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
device = "cuda"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
seed = 1337
|
seed = 1337
|
||||||
|
|
||||||
# Load model checkpoint and initialize model.
|
# Load model checkpoint and initialize model.
|
||||||
model = load_model(f'config_{model_name}.json',
|
model = load_model(
|
||||||
f'best_model_{model_name}.pt',
|
config_path=f'config_{model_name}.json',
|
||||||
1270)
|
model_path=f'best_model_{model_name}.pt',
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
model.eval()
|
model.eval()
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user