feat: Save model with params in name and log losses

 train.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -43,6 +43,9 @@ class TrainConfig:
 def main():
     config = TrainConfig()
 
+    model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+    checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
+
     # --- 1. Data Loading ---
     print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
     train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
@@ -160,9 +163,9 @@ def main():
         val_losses_surv.append(avg_val_loss_surv)
         val_losses_total.append(total_val_loss)
 
-        print(f"Epoch {epoch+1} Summary: \n"
-              f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
-              f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
+        print(f"Epoch {epoch+1} Summary: \n"
+              f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
+              f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
               f" Learning Rate: {lr:.6f}")
 
         # --- Early Stopping Check ---
@@ -170,7 +173,7 @@ def main():
             best_val_loss = total_val_loss
             patience_counter = 0
             print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
-            torch.save(model.state_dict(), 'best_model_checkpoint.pt')
+            torch.save(model.state_dict(), checkpoint_filename)
         else:
             if epoch >= config.warmup_epochs:
                 patience_counter += 1
@@ -183,12 +186,20 @@ def main():
     # --- Save Best Model at the End ---
     if best_val_loss != float('inf'):
         print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
-        model.load_state_dict(torch.load('best_model_checkpoint.pt'))
-        print("Saving final best model to best_model.pt")
-        torch.save(model.state_dict(), 'best_model.pt')
+        model.load_state_dict(torch.load(checkpoint_filename))
+        print(f"Saving final best model to {model_filename}")
+        torch.save(model.state_dict(), model_filename)
     else:
         print("\nTraining finished. No best model to save as validation loss never improved.")
 
+    # --- Save losses to a txt file ---
+    losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
+    with open(losses_filename, 'w') as f:
+        f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
+        for i in range(len(train_losses_total)):
+            f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
+    print(f"\nLosses saved to {losses_filename}")
+
     # --- Plot and Save Loss Curves ---
     num_epochs = len(train_losses_total)
     epochs = range(1, num_epochs + 1)
@@ -231,4 +242,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    main()
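
Note (not part of the commit): because the new filenames embed n_embd, n_layer and n_head, the saved weights can be located and inspected later without the original run. A minimal sketch, assuming hypothetical config values n_embd=256, n_layer=8, n_head=8; the model class is not shown in this diff, so only the raw state_dict is loaded here.

import torch

# Hypothetical filename; the real one depends on the TrainConfig values used for the run.
model_filename = "best_model_n_embd_256_n_layer_8_n_head_8.pt"

# The commit saves model.state_dict(), so this yields a dict of parameter tensors.
state_dict = torch.load(model_filename, map_location="cpu")
print({name: tuple(t.shape) for name, t in list(state_dict.items())[:5]})

# To resume or evaluate, rebuild the model with the same n_embd/n_layer/n_head
# and call model.load_state_dict(state_dict).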
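
Likewise, the losses file is a plain CSV with the header written above, so the curves can be re-plotted after the fact. A sketch assuming pandas and matplotlib and the same hypothetical config values; the column names are taken from the header written in the diff.

import pandas as pd
import matplotlib.pyplot as plt

# Hypothetical filename; pattern matches losses_n_embd_{..}_n_layer_{..}_n_head_{..}.txt
losses_filename = "losses_n_embd_256_n_layer_8_n_head_8.txt"

# Columns: epoch, train_loss_ce, train_loss_surv, train_loss_total,
#          val_loss_ce, val_loss_surv, val_loss_total
df = pd.read_csv(losses_filename)

plt.plot(df["epoch"], df["train_loss_total"], label="train total")
plt.plot(df["epoch"], df["val_loss_total"], label="val total")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.savefig("loss_curves_reloaded.png")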