feat: Save model with params in name and log losses
This commit is contained in:
27
train.py
27
train.py
@@ -43,6 +43,9 @@ class TrainConfig:
|
|||||||
def main():
|
def main():
|
||||||
config = TrainConfig()
|
config = TrainConfig()
|
||||||
|
|
||||||
|
model_filename = f"best_model_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
|
||||||
|
checkpoint_filename = f"best_model_checkpoint_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.pt"
|
||||||
|
|
||||||
# --- 1. Data Loading ---
|
# --- 1. Data Loading ---
|
||||||
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
|
print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
|
||||||
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
|
train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
|
||||||
@@ -160,9 +163,9 @@ def main():
|
|||||||
val_losses_surv.append(avg_val_loss_surv)
|
val_losses_surv.append(avg_val_loss_surv)
|
||||||
val_losses_total.append(total_val_loss)
|
val_losses_total.append(total_val_loss)
|
||||||
|
|
||||||
print(f"Epoch {epoch+1} Summary: \n"
|
print(f"Epoch {epoch+1} Summary: \n"
|
||||||
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
|
f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n"
|
||||||
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
|
f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n"
|
||||||
f" Learning Rate: {lr:.6f}")
|
f" Learning Rate: {lr:.6f}")
|
||||||
|
|
||||||
# --- Early Stopping Check ---
|
# --- Early Stopping Check ---
|
||||||
@@ -170,7 +173,7 @@ def main():
|
|||||||
best_val_loss = total_val_loss
|
best_val_loss = total_val_loss
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
|
print(f"Validation loss improved to {best_val_loss:.4f}. Saving checkpoint...")
|
||||||
torch.save(model.state_dict(), 'best_model_checkpoint.pt')
|
torch.save(model.state_dict(), checkpoint_filename)
|
||||||
else:
|
else:
|
||||||
if epoch >= config.warmup_epochs:
|
if epoch >= config.warmup_epochs:
|
||||||
patience_counter += 1
|
patience_counter += 1
|
||||||
@@ -183,12 +186,20 @@ def main():
|
|||||||
# --- Save Best Model at the End ---
|
# --- Save Best Model at the End ---
|
||||||
if best_val_loss != float('inf'):
|
if best_val_loss != float('inf'):
|
||||||
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
|
print(f"\nTraining finished. Loading best model from checkpoint with validation loss {best_val_loss:.4f}.")
|
||||||
model.load_state_dict(torch.load('best_model_checkpoint.pt'))
|
model.load_state_dict(torch.load(checkpoint_filename))
|
||||||
print("Saving final best model to best_model.pt")
|
print(f"Saving final best model to {model_filename}")
|
||||||
torch.save(model.state_dict(), 'best_model.pt')
|
torch.save(model.state_dict(), model_filename)
|
||||||
else:
|
else:
|
||||||
print("\nTraining finished. No best model to save as validation loss never improved.")
|
print("\nTraining finished. No best model to save as validation loss never improved.")
|
||||||
|
|
||||||
|
# --- Save losses to a txt file ---
|
||||||
|
losses_filename = f"losses_n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}.txt"
|
||||||
|
with open(losses_filename, 'w') as f:
|
||||||
|
f.write("epoch,train_loss_ce,train_loss_surv,train_loss_total,val_loss_ce,val_loss_surv,val_loss_total\n")
|
||||||
|
for i in range(len(train_losses_total)):
|
||||||
|
f.write(f"{i+1},{train_losses_ce[i]},{train_losses_surv[i]},{train_losses_total[i]},{val_losses_ce[i]},{val_losses_surv[i]},{val_losses_total[i]}\n")
|
||||||
|
print(f"\nLosses saved to {losses_filename}")
|
||||||
|
|
||||||
# --- Plot and Save Loss Curves ---
|
# --- Plot and Save Loss Curves ---
|
||||||
num_epochs = len(train_losses_total)
|
num_epochs = len(train_losses_total)
|
||||||
epochs = range(1, num_epochs + 1)
|
epochs = range(1, num_epochs + 1)
|
||||||
@@ -231,4 +242,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
Reference in New Issue
Block a user