From b54c54a60b4c0395918ffc74e3582b77949dc5a5 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 9 Jan 2026 12:49:29 +0800 Subject: [PATCH] Refactor Trainer class: improve training statistics calculation and logging for NLL and regularization --- train.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index 30a4452..eef93fd 100644 --- a/train.py +++ b/train.py @@ -289,8 +289,9 @@ class Trainer: patience_counter = 0 for epoch in range(1, self.cfg.max_epochs + 1): self.model.train() - running_nll = 0.0 - running_reg = 0.0 + total_train_pairs = 0 + total_train_nll = 0.0 + total_train_reg = 0.0 pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100) batch_count = 0 @@ -311,6 +312,7 @@ class Trainer: if res is None: continue dt, b_prev, t_prev, b_next, t_next = res + num_pairs = dt.size(0) self.optimizer.zero_grad() lr = self.compute_lr(self.global_step) self.optimizer.param_groups[0]['lr'] = lr @@ -333,12 +335,15 @@ class Trainer: nll = nll_vec.mean() loss = nll + reg batch_count += 1 - running_nll += nll.item() - running_reg += reg.item() + total_train_pairs += num_pairs + total_train_nll += nll_vec.sum().item() + total_train_reg += reg.item() * num_pairs + avg_train_nll = total_train_nll / total_train_pairs + avg_train_reg = total_train_reg / total_train_pairs pbar.set_postfix({ "lr": lr, - "NLL": running_nll / batch_count, - "Reg": running_reg / batch_count, + "NLL": avg_train_nll, + "Reg": avg_train_reg, }) loss.backward() if self.cfg.grad_clip > 0: @@ -351,8 +356,8 @@ class Trainer: print("No valid batches in this epoch, skipping validation.") continue - train_nll = running_nll / batch_count - train_reg = running_reg / batch_count + train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0 + train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0 self.model.eval() total_val_pairs = 0 @@ -422,7 +427,9 @@ class Trainer: tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:") tqdm.write(f" Train NLL: {train_nll:.4f}") + tqdm.write(f" Train Reg: {train_reg:.4f}") tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC") + tqdm.write(f" Val Reg: {val_reg:.4f}") with open(os.path.join(self.out_dir, "training_history.json"), "w") as f: json.dump(history, f, indent=4)