update
This commit is contained in:
@@ -432,6 +432,7 @@ def evaluate_auc_pipeline(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Evaluate AUC")
|
parser = argparse.ArgumentParser(description="Evaluate AUC")
|
||||||
|
parser.add_argument("--model_name", type=str, default="n_embd_256_n_layer_16_n_head_16", help="Model checkpoint name")
|
||||||
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
|
parser.add_argument("--dataset_subset_size", type=int, default=-1, help="Dataset subset size for evaluation")
|
||||||
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
|
parser.add_argument("--n_bootstrap", type=int, default=1, help="Number of bootstrap samples")
|
||||||
# Optional filtering/chunking parameters:
|
# Optional filtering/chunking parameters:
|
||||||
@@ -440,7 +441,8 @@ def main():
|
|||||||
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
|
parser.add_argument("--n_jobs", type=int, default=-1, help="Number of parallel jobs to run")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
output_path = './'
|
model_name = args.model_name
|
||||||
|
output_path = f'auc_evaluation_{model_name}'
|
||||||
dataset_subset_size = args.dataset_subset_size
|
dataset_subset_size = args.dataset_subset_size
|
||||||
|
|
||||||
# Create output folder if it doesn't exist.
|
# Create output folder if it doesn't exist.
|
||||||
@@ -450,8 +452,8 @@ def main():
|
|||||||
seed = 1337
|
seed = 1337
|
||||||
|
|
||||||
# Load model checkpoint and initialize model.
|
# Load model checkpoint and initialize model.
|
||||||
model = load_model('config_n_embd_256_n_layer_16_n_head_16.json',
|
model = load_model(f'config_{model_name}.json',
|
||||||
'best_model_n_embd_256_n_layer_16_n_head_16.pt',
|
f'best_model_{model_name}.pt',
|
||||||
1270)
|
1270)
|
||||||
model.eval()
|
model.eval()
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
Reference in New Issue
Block a user