evaluate_auc: use new utils.load_model (infer vocab, model variants) and dynamic device; remove unused imports

This commit is contained in:
2025-10-22 11:35:52 +08:00
parent dfdf64da9a
commit 92a5bd4a83

View File

@@ -1,8 +1,4 @@
import scipy.stats
import scipy
import warnings
import torch
from models import TimeAwareGPT2
from tqdm import tqdm
import pandas as pd
import numpy as np
@@ -450,13 +446,15 @@ def main():
# Create output folder if it doesn't exist.
Path(output_path).mkdir(exist_ok=True, parents=True)
device = "cuda"
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 1337
# Load model checkpoint and initialize model.
model = load_model(f'config_{model_name}.json',
f'best_model_{model_name}.pt',
1270)
model = load_model(
config_path=f'config_{model_name}.json',
model_path=f'best_model_{model_name}.pt',
device=device,
)
model.eval()
model = model.to(device)