diff --git a/train.py b/train.py index 9fea9ae..7c46cb0 100644 --- a/train.py +++ b/train.py @@ -85,7 +85,8 @@ def main(): checkpoint_filename = f"best_model_checkpoint_{model_suffix}.pt" # --- 0. Save Configuration --- - config_filename = f"config_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json" + # Include the model name in the config filename so configs from different architectures with the same dimensions do not overwrite each other + config_filename = f"config_{config.model_name}_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.json" config_dict = {k: v for k, v in vars(config).items() if not k.startswith('__')} with open(config_filename, 'w') as f: json.dump(config_dict, f, indent=4)