diff --git a/train.py b/train.py
index af147ea..9d54ac8 100644
--- a/train.py
+++ b/train.py
@@ -294,7 +294,7 @@ class Trainer:
     def train(self) -> None:
         history = []
-        best_val_loss = float('inf')
+        best_val_score = float('inf')
         patience_counter = 0

         for epoch in range(1, self.cfg.max_epochs + 1):
             self.model.train()
@@ -302,6 +302,7 @@ class Trainer:
             running_reg = 0.0
             pbar = tqdm(self.train_loader,
                         desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
+            batch_count = 0
             for batch in pbar:
                 (
                     event_seq,
@@ -445,7 +446,7 @@ class Trainer:
             # Check for improvement
             if val_nll < best_val_score:
                 best_val_score = val_nll
-                patient_counter = 0
+                patience_counter = 0
                 tqdm.write(" ✓ New best validation score. Saving checkpoint.")

                 torch.save({
@@ -456,21 +457,13 @@ class Trainer:
                     "optimizer_state_dict": self.optimizer.state_dict(),
                 }, self.best_path)
             else:
-                patient_counter += 1
-                if epoch+1 >= self.cfg.min_epochs and patient_counter >= self.cfg.patient_epochs:
+                patience_counter += 1
+                if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
                     tqdm.write(
-                        f"\n⚠ No improvement in validation score for {patient_counter} epochs. Early stopping.")
+                        f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
                     return
                 tqdm.write(
-                    f"   No improvement (patience: {patient_counter}/{self.cfg.patient_epochs})")
-
-                torch.save({
-                    "epoch": epoch,
-                    "global_step": self.global_step,
-                    "model_state_dict": self.ema_model.state_dict(),
-                    "criterion_state_dict": self.criterion.state_dict(),
-                    "optimizer_state_dict": self.optimizer.state_dict(),
-                }, self.last_path)
+                    f"   No improvement (patience: {patience_counter}/{self.cfg.patience})")

         tqdm.write("\n🎉 Training complete!")