Refactor Trainer class: improve training statistics calculation and logging for NLL and regularization

This commit is contained in:
2026-01-09 12:49:29 +08:00
parent aff0fe480b
commit b54c54a60b

View File

@@ -289,8 +289,9 @@ class Trainer:
patience_counter = 0
for epoch in range(1, self.cfg.max_epochs + 1):
self.model.train()
running_nll = 0.0
running_reg = 0.0
total_train_pairs = 0
total_train_nll = 0.0
total_train_reg = 0.0
pbar = tqdm(self.train_loader,
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
batch_count = 0
@@ -311,6 +312,7 @@ class Trainer:
if res is None:
continue
dt, b_prev, t_prev, b_next, t_next = res
num_pairs = dt.size(0)
self.optimizer.zero_grad()
lr = self.compute_lr(self.global_step)
self.optimizer.param_groups[0]['lr'] = lr
@@ -333,12 +335,15 @@ class Trainer:
nll = nll_vec.mean()
loss = nll + reg
batch_count += 1
running_nll += nll.item()
running_reg += reg.item()
total_train_pairs += num_pairs
total_train_nll += nll_vec.sum().item()
total_train_reg += reg.item() * num_pairs
avg_train_nll = total_train_nll / total_train_pairs
avg_train_reg = total_train_reg / total_train_pairs
pbar.set_postfix({
"lr": lr,
"NLL": running_nll / batch_count,
"Reg": running_reg / batch_count,
"NLL": avg_train_nll,
"Reg": avg_train_reg,
})
loss.backward()
if self.cfg.grad_clip > 0:
@@ -351,8 +356,8 @@ class Trainer:
print("No valid batches in this epoch, skipping validation.")
continue
train_nll = running_nll / batch_count
train_reg = running_reg / batch_count
train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
self.model.eval()
total_val_pairs = 0
@@ -422,7 +427,9 @@ class Trainer:
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
tqdm.write(f" Train NLL: {train_nll:.4f}")
tqdm.write(f" Train Reg: {train_reg:.4f}")
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
tqdm.write(f" Val Reg: {val_reg:.4f}")
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
json.dump(history, f, indent=4)