Refactor training logic to improve early stopping mechanism and variable naming
This commit is contained in:
21
train.py
21
train.py
@@ -294,7 +294,7 @@ class Trainer:
 
 def train(self) -> None:
 history = []
-best_val_loss = float('inf')
+best_val_score = float('inf')
 patience_counter = 0
 for epoch in range(1, self.cfg.max_epochs + 1):
 self.model.train()
@@ -302,6 +302,7 @@ class Trainer:
 running_reg = 0.0
 pbar = tqdm(self.train_loader,
 desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
+batch_count = 0
 for batch in pbar:
 (
 event_seq,
@@ -445,7 +446,7 @@ class Trainer:
 # Check for improvement
 if val_nll < best_val_score:
 best_val_score = val_nll
-patient_counter = 0
+patience_counter = 0
 tqdm.write("  ✓ New best validation score. Saving checkpoint.")
 
 torch.save({
@@ -456,21 +457,13 @@ class Trainer:
 "optimizer_state_dict": self.optimizer.state_dict(),
 }, self.best_path)
 else:
-patient_counter += 1
-if epoch+1 >= self.cfg.min_epochs and patient_counter >= self.cfg.patient_epochs:
+patience_counter += 1
+if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
 tqdm.write(
-f"\n⚠ No improvement in validation score for {patient_counter} epochs. Early stopping.")
+f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
 return
 tqdm.write(
-f"  No improvement (patience: {patient_counter}/{self.cfg.patient_epochs})")
-
-torch.save({
-"epoch": epoch,
-"global_step": self.global_step,
-"model_state_dict": self.ema_model.state_dict(),
-"criterion_state_dict": self.criterion.state_dict(),
-"optimizer_state_dict": self.optimizer.state_dict(),
-}, self.last_path)
+f"  No improvement (patience: {patience_counter}/{self.cfg.patience})")
 
 tqdm.write("\n🎉 Training complete!")
 
Reference in New Issue
Block a user