feat: Save model with params in name and log losses

This commit is contained in:
2025-10-17 10:44:17 +08:00
parent 7e8d8d307b
commit fe0304a96a

View File

@@ -43,6 +43,9 @@ class TrainConfig:
def main():
config = TrainConfig()
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
# --- 1. Data Loading ---
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
@@ -160,9 +163,9 @@ def main():
val_losses_surv.append(avg_val_loss_surv)
val_losses_total.append(total_val_loss)
print(f"Epoch {epoch+1} Summary: \n"
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
print(f"Epoch {epoch+1} Summary: \n"
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
f" Learning Rate: {lr:.6f}")
# --- Early Stopping Check ---
@@ -170,7 +173,7 @@ def main():
best_val_loss = total_val_loss
patience_counter = 0
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
torch.save(model.state_dict(), 'best_model_checkpoint.pt')
torch.save(model.state_dict(), checkpoint_filename)
else:
if epoch >= config.warmup_epochs:
patience_counter += 1
@@ -183,12 +186,20 @@ def main():
# --- Save Best Model at the End ---
if best_val_loss != float('inf'):
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
print("Saving final best model to best_model.pt")
torch.save(model.state_dict(), 'best_model.pt')
model.load_state_dict(torch.load(checkpoint_filename))
print(f"Saving final best model to {model_filename}")
torch.save(model.state_dict(), model_filename)
else:
print("\nTraining finished. No best model to save as validation loss never improved.")
# --- Save losses to a txt file ---
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
with open(losses_filename, 'w') as f:
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
for i in range(len(train_losses_total)):
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
print(f"\nLosses saved to {losses_filename}")
# --- Plot and Save Loss Curves ---
num_epochs = len(train_losses_total)
epochs = range(1, num_epochs + 1)
@@ -231,4 +242,4 @@ def main():
if __name__ == '__main__':
main()
main()