diff --git a/evaluate_auc.py b/evaluate_auc.py index 7dfc6c1..e639613 100644 --- a/evaluate_auc.py +++ b/evaluate_auc.py @@ -432,6 +432,7 @@ def evaluate_auc_pipeline( def main(): parser = argparse.ArgumentParser(description="Evaluate AUC") + parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name") parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation") parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples") # Optional filtering/chunking parameters: @@ -440,7 +441,8 @@ def main(): parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run") args = parser.parse_args() - output_path = './' + model_name = args.model_name + output_path = f'auc_evaluation_{model_name}' dataset_subset_size = args.dataset_subset_size # Create output folder if it doesn't exist. @@ -450,8 +452,8 @@ def main(): seed = 1337 # Load model checkpoint and initialize model. - model = load_model('config_n_embd_256_n_layer_16_n_head_16.json', - 'best_model_n_embd_256_n_layer_16_n_head_16.pt', + model = load_model(f'config_{model_name}.json', + f'best_model_{model_name}.pt', 1270) model.eval() model = model.to(device)