Refactor training logic to improve early stopping mechanism and variable naming
This commit is contained in:
21
train.py
21
train.py
@@ -294,7 +294,7 @@ class Trainer:
|
|||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
history = []
|
history = []
|
||||||
best_val_loss = float('inf')
|
best_val_score = float('inf')
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
for epoch in range(1, self.cfg.max_epochs + 1):
|
for epoch in range(1, self.cfg.max_epochs + 1):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
@@ -302,6 +302,7 @@ class Trainer:
|
|||||||
running_reg = 0.0
|
running_reg = 0.0
|
||||||
pbar = tqdm(self.train_loader,
|
pbar = tqdm(self.train_loader,
|
||||||
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
|
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
|
||||||
|
batch_count = 0
|
||||||
for batch in pbar:
|
for batch in pbar:
|
||||||
(
|
(
|
||||||
event_seq,
|
event_seq,
|
||||||
@@ -445,7 +446,7 @@ class Trainer:
|
|||||||
# Check for improvement
|
# Check for improvement
|
||||||
if val_nll < best_val_score:
|
if val_nll < best_val_score:
|
||||||
best_val_score = val_nll
|
best_val_score = val_nll
|
||||||
patient_counter = 0
|
patience_counter = 0
|
||||||
tqdm.write(" ✓ New best validation score. Saving checkpoint.")
|
tqdm.write(" ✓ New best validation score. Saving checkpoint.")
|
||||||
|
|
||||||
torch.save({
|
torch.save({
|
||||||
@@ -456,21 +457,13 @@ class Trainer:
|
|||||||
"optimizer_state_dict": self.optimizer.state_dict(),
|
"optimizer_state_dict": self.optimizer.state_dict(),
|
||||||
}, self.best_path)
|
}, self.best_path)
|
||||||
else:
|
else:
|
||||||
patient_counter += 1
|
patience_counter += 1
|
||||||
if epoch+1 >= self.cfg.min_epochs and patient_counter >= self.cfg.patient_epochs:
|
if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"\n⚠ No improvement in validation score for {patient_counter} epochs. Early stopping.")
|
f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
|
||||||
return
|
return
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f" No improvement (patience: {patient_counter}/{self.cfg.patient_epochs})")
|
f" No improvement (patience: {patience_counter}/{self.cfg.patience})")
|
||||||
|
|
||||||
torch.save({
|
|
||||||
"epoch": epoch,
|
|
||||||
"global_step": self.global_step,
|
|
||||||
"model_state_dict": self.ema_model.state_dict(),
|
|
||||||
"criterion_state_dict": self.criterion.state_dict(),
|
|
||||||
"optimizer_state_dict": self.optimizer.state_dict(),
|
|
||||||
}, self.last_path)
|
|
||||||
|
|
||||||
tqdm.write("\n🎉 Training complete!")
|
tqdm.write("\n🎉 Training complete!")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user