This commit is contained in:
2025-10-20 16:22:15 +08:00
parent 88cccdad2e
commit ddb7dbfc67

View File

@@ -432,6 +432,7 @@ def evaluate_auc_pipeline(
def main():
parser = argparse.ArgumentParser(description="Evaluate AUC")
parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name")
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
# Optional filtering/chunking parameters:
@@ -440,7 +441,8 @@ def main():
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
args = parser.parse_args()
output_path = './'
model_name = args.model_name
output_path = f'auc_evaluation_{model_name}'
dataset_subset_size = args.dataset_subset_size
# Create output folder if it doesn't exist.
@@ -450,8 +452,8 @@ def main():
seed = 1337
# Load model checkpoint and initialize model.
model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',
'best_model_n_embd_256_n_layer_16_n_head_16.pt',
model = load_model(f'config_{model_name}.json',
f'best_model_{model_name}.pt',
1270)
model.eval()
model = model.to(device)