Refactor Trainer class: improve training statistics calculation and logging for NLL and regularization
This commit is contained in:
23
train.py
23
train.py
@@ -289,8 +289,9 @@ class Trainer:
|
||||
patience_counter = 0
|
||||
for epoch in range(1, self.cfg.max_epochs + 1):
|
||||
self.model.train()
|
||||
running_nll = 0.0
|
||||
running_reg = 0.0
|
||||
total_train_pairs = 0
|
||||
total_train_nll = 0.0
|
||||
total_train_reg = 0.0
|
||||
pbar = tqdm(self.train_loader,
|
||||
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
|
||||
batch_count = 0
|
||||
@@ -311,6 +312,7 @@ class Trainer:
|
||||
if res is None:
|
||||
continue
|
||||
dt, b_prev, t_prev, b_next, t_next = res
|
||||
num_pairs = dt.size(0)
|
||||
self.optimizer.zero_grad()
|
||||
lr = self.compute_lr(self.global_step)
|
||||
self.optimizer.param_groups[0]['lr'] = lr
|
||||
@@ -333,12 +335,15 @@ class Trainer:
|
||||
nll = nll_vec.mean()
|
||||
loss = nll + reg
|
||||
batch_count += 1
|
||||
running_nll += nll.item()
|
||||
running_reg += reg.item()
|
||||
total_train_pairs += num_pairs
|
||||
total_train_nll += nll_vec.sum().item()
|
||||
total_train_reg += reg.item() * num_pairs
|
||||
avg_train_nll = total_train_nll / total_train_pairs
|
||||
avg_train_reg = total_train_reg / total_train_pairs
|
||||
pbar.set_postfix({
|
||||
"lr": lr,
|
||||
"NLL": running_nll / batch_count,
|
||||
"Reg": running_reg / batch_count,
|
||||
"NLL": avg_train_nll,
|
||||
"Reg": avg_train_reg,
|
||||
})
|
||||
loss.backward()
|
||||
if self.cfg.grad_clip > 0:
|
||||
@@ -351,8 +356,8 @@ class Trainer:
|
||||
print("No valid batches in this epoch, skipping validation.")
|
||||
continue
|
||||
|
||||
train_nll = running_nll / batch_count
|
||||
train_reg = running_reg / batch_count
|
||||
train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
|
||||
train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
|
||||
|
||||
self.model.eval()
|
||||
total_val_pairs = 0
|
||||
@@ -422,7 +427,9 @@ class Trainer:
|
||||
|
||||
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
|
||||
tqdm.write(f" Train NLL: {train_nll:.4f}")
|
||||
tqdm.write(f" Train Reg: {train_reg:.4f}")
|
||||
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
|
||||
tqdm.write(f" Val Reg: {val_reg:.4f}")
|
||||
|
||||
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
|
||||
json.dump(history, f, indent=4)
|
||||
|
||||
Reference in New Issue
Block a user