feat: Save model and checkpoint under hyperparameter-tagged filenames and write per-epoch losses to a text file

This commit is contained in:
2025-10-17 10:44:17 +08:00
parent 7e8d8d307b
commit fe0304a96a

View File

@@ -43,6 +43,9 @@ class TrainConfig:
def main(): def main():
config = TrainConfig() config = TrainConfig()
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
# --- 1. Data Loading --- # --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
@@ -170,7 +173,7 @@ def main():
best_val_loss = total_val_loss best_val_loss = total_val_loss
patience_counter = 0 patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...") print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
torch.save(model.state_dict(), 'best_model_checkpoint.pt') torch.save(model.state_dict(), checkpoint_filename)
else: else:
if epoch >= config.warmup_epochs: if epoch >= config.warmup_epochs:
patience_counter += 1 patience_counter += 1
@@ -183,12 +186,20 @@ def main():
# --- Save Best Model at the End --- # --- Save Best Model at the End ---
if best_val_loss != float('inf'): if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.") print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
model.load_state_dict(torch.load('best_model_checkpoint.pt')) model.load_state_dict(torch.load(checkpoint_filename))
print("Saving final best model to best_model.pt") print(f"Saving final best model to {model_filename}")
torch.save(model.state_dict(), 'best_model.pt') torch.save(model.state_dict(), model_filename)
else: else:
print("\nTraining finished. No best model to save as validation loss never improved.") print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Save losses to a txt file ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
with open(losses_filename, 'w') as f:
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# --- Plot and Save Loss Curves --- # --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total) num_epochs = len(train_losses_total)
epochs = range(1, num_epochs + 1) epochs = range(1, num_epochs + 1)