From fe0304a96a60da37ad166225e1a9f18708e12127 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 17 Oct 2025 10:44:17 +0800 Subject: [PATCH] feat: Save model with params in name and log losses --- train.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index 1d9ed0c..275f12c 100644 --- a/train.py +++ b/train.py @@ -43,6 +43,9 @@ class TrainConfig: def main(): config = TrainConfig() + model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt" + checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt" + # --- 1. Data Loading --- print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) @@ -160,9 +163,9 @@ def main(): val_losses_surv.append(avg_val_loss_surv) val_losses_total.append(total_val_loss) - print(f"Epoch {epoch+1} Summary: \n" - f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n" - f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n" + print(f"Epoch {epoch+1} Summary: \n" + f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n" + f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n" f" Learning Rate: {lr:.6f}") # --- Early Stopping Check --- @@ -170,7 +173,7 @@ def main(): best_val_loss = total_val_loss patience_counter = 0 print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") - torch.save(model.state_dict(), 'best_model_checkpoint.pt') + torch.save(model.state_dict(), checkpoint_filename) else: if epoch >= config.warmup_epochs: patience_counter += 1 @@ -183,12 +186,20 @@ def main(): # --- Save Best Model at the End --- if best_val_loss != float('inf'): print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") - model.load_state_dict(torch.load('best_model_checkpoint.pt')) - print("Saving final best model to best_model.pt") - torch.save(model.state_dict(), 'best_model.pt') + model.load_state_dict(torch.load(checkpoint_filename)) + print(f"Saving final best model to {model_filename}") + torch.save(model.state_dict(), model_filename) else: print("\nTraining finished. No best model to save as validation loss never improved.") + # --- Save losses to a txt file --- + losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt" + with open(losses_filename, 'w') as f: + f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n") + for i in range(len(train_losses_total)): + f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n") + print(f"\nLosses saved to {losses_filename}") + # --- Plot and Save Loss Curves --- num_epochs = len(train_losses_total) epochs = range(1, num_epochs + 1) @@ -231,4 +242,4 @@ def main(): if __name__ == '__main__': - main() + main() \ No newline at end of file