Refactor training logic to improve early stopping mechanism and variable naming

This commit is contained in:
2026-01-08 00:07:15 +08:00
parent 811b2e1a46
commit 33ba7e6c1d

View File

@@ -294,7 +294,7 @@ class Trainer:
def train(self) -> None:
history = []
best_val_loss = float('inf')
best_val_score = float('inf')
patience_counter = 0
for epoch in range(1, self.cfg.max_epochs + 1):
self.model.train()
@@ -302,6 +302,7 @@ class Trainer:
running_reg = 0.0
pbar = tqdm(self.train_loader,
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
batch_count = 0
for batch in pbar:
(
event_seq,
@@ -445,7 +446,7 @@ class Trainer:
# Check for improvement
if val_nll < best_val_score:
best_val_score = val_nll
patient_counter = 0
patience_counter = 0
tqdm.write(" ✓ New best validation score. Saving checkpoint.")
torch.save({
@@ -456,21 +457,13 @@ class Trainer:
"optimizer_state_dict": self.optimizer.state_dict(),
}, self.best_path)
else:
patient_counter += 1
if epoch+1 >= self.cfg.min_epochs and patient_counter >= self.cfg.patient_epochs:
patience_counter += 1
if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
tqdm.write(
f"\n⚠ No improvement in validation score for {patient_counter} epochs. Early stopping.")
f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
return
tqdm.write(
f" No improvement (patience: {patient_counter}/{self.cfg.patient_epochs})")
torch.save({
"epoch": epoch,
"global_step": self.global_step,
"model_state_dict": self.ema_model.state_dict(),
"criterion_state_dict": self.criterion.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
}, self.last_path)
f" No improvement (patience: {patience_counter}/{self.cfg.patience})")
tqdm.write("\n🎉 Training complete!")