Refactor training logic to improve early stopping mechanism and variable naming

2026-01-08 00:07:15 +08:00
parent 811b2e1a46
commit 33ba7e6c1d


@@ -294,7 +294,7 @@ class Trainer:
     def train(self) -> None:
         history = []
-        best_val_loss = float('inf')
+        best_val_score = float('inf')
         patience_counter = 0
         for epoch in range(1, self.cfg.max_epochs + 1):
             self.model.train()
@@ -302,6 +302,7 @@ class Trainer:
             running_reg = 0.0
             pbar = tqdm(self.train_loader,
                         desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
+            batch_count = 0
             for batch in pbar:
                 (
                     event_seq,
@@ -445,7 +446,7 @@ class Trainer:
             # Check for improvement
             if val_nll < best_val_score:
                 best_val_score = val_nll
-                patient_counter = 0
+                patience_counter = 0
                 tqdm.write(" ✓ New best validation score. Saving checkpoint.")
                 torch.save({
@@ -456,21 +457,13 @@ class Trainer:
                     "optimizer_state_dict": self.optimizer.state_dict(),
                 }, self.best_path)
             else:
-                patient_counter += 1
-                if epoch+1 >= self.cfg.min_epochs and patient_counter >= self.cfg.patient_epochs:
+                patience_counter += 1
+                if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
                     tqdm.write(
-                        f"\n⚠ No improvement in validation score for {patient_counter} epochs. Early stopping.")
+                        f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
                     return
                 tqdm.write(
-                    f"    No improvement (patience: {patient_counter}/{self.cfg.patient_epochs})")
+                    f"    No improvement (patience: {patience_counter}/{self.cfg.patience})")
-            torch.save({
-                "epoch": epoch,
-                "global_step": self.global_step,
-                "model_state_dict": self.ema_model.state_dict(),
-                "criterion_state_dict": self.criterion.state_dict(),
-                "optimizer_state_dict": self.optimizer.state_dict(),
-            }, self.last_path)
         tqdm.write("\n🎉 Training complete!")
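For context, a minimal, self-contained sketch of the early-stopping pattern the refactored code follows: track the best validation score, reset the patience counter on improvement, and only allow stopping once the warmup period has passed. The StopConfig dataclass and the validate stub below are illustrative stand-ins, not the project's Trainer API; only the names max_epochs, warmup_epochs, and patience come from the diff.

from dataclasses import dataclass


@dataclass
class StopConfig:
    max_epochs: int = 50
    warmup_epochs: int = 5  # no early stop before this epoch
    patience: int = 3       # epochs without improvement tolerated after warmup


def validate(epoch: int) -> float:
    # Toy validation score that improves, then plateaus, to exercise the logic.
    return max(1.0, 10.0 - epoch)


def train(cfg: StopConfig) -> None:
    best_val_score = float("inf")  # lower is better (e.g. validation NLL)
    patience_counter = 0
    for epoch in range(1, cfg.max_epochs + 1):
        val_nll = validate(epoch)
        if val_nll < best_val_score:
            best_val_score = val_nll
            patience_counter = 0  # any improvement resets patience
        else:
            patience_counter += 1
            # Gating on warmup_epochs keeps a noisy start from ending the run.
            if epoch >= cfg.warmup_epochs and patience_counter >= cfg.patience:
                print(f"No improvement for {patience_counter} epochs; stopping.")
                return


if __name__ == "__main__":
    train(StopConfig())

The rename is not cosmetic: the old code initialized patience_counter and best_val_loss but read patient_counter and best_val_score, so the counter the stopping check consulted was never the one being maintained.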
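The commit also drops the per-epoch save to self.last_path while keeping the best-checkpoint save. For completeness, a sketch of restoring from such a checkpoint: the dictionary keys are taken from the removed last-checkpoint block in the diff, while the construction of model, criterion, and optimizer is assumed to come from the project's own setup code.

import torch


def load_checkpoint(path: str, model, criterion, optimizer, device: str = "cpu"):
    # Restore state saved via torch.save({...}, path); keys match the diff.
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])  # EMA weights at save time
    criterion.load_state_dict(ckpt["criterion_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    # Resume bookkeeping where the saved run left off.
    return ckpt["epoch"], ckpt["global_step"]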