Refactor Trainer class: improve training statistics calculation and logging for NLL and regularization

This commit is contained in:
2026-01-09 12:49:29 +08:00
parent aff0fe480b
commit b54c54a60b

View File

@@ -289,8 +289,9 @@ class Trainer:
patience_counter = 0 patience_counter = 0
for epoch in range(1, self.cfg.max_epochs + 1): for epoch in range(1, self.cfg.max_epochs + 1):
self.model.train() self.model.train()
running_nll = 0.0 total_train_pairs = 0
running_reg = 0.0 total_train_nll = 0.0
total_train_reg = 0.0
pbar = tqdm(self.train_loader, pbar = tqdm(self.train_loader,
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100) desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
batch_count = 0 batch_count = 0
@@ -311,6 +312,7 @@ class Trainer:
if res is None: if res is None:
continue continue
dt, b_prev, t_prev, b_next, t_next = res dt, b_prev, t_prev, b_next, t_next = res
num_pairs = dt.size(0)
self.optimizer.zero_grad() self.optimizer.zero_grad()
lr = self.compute_lr(self.global_step) lr = self.compute_lr(self.global_step)
self.optimizer.param_groups[0]['lr'] = lr self.optimizer.param_groups[0]['lr'] = lr
@@ -333,12 +335,15 @@ class Trainer:
nll = nll_vec.mean() nll = nll_vec.mean()
loss = nll + reg loss = nll + reg
batch_count += 1 batch_count += 1
running_nll += nll.item() total_train_pairs += num_pairs
running_reg += reg.item() total_train_nll += nll_vec.sum().item()
total_train_reg += reg.item() * num_pairs
avg_train_nll = total_train_nll / total_train_pairs
avg_train_reg = total_train_reg / total_train_pairs
pbar.set_postfix({ pbar.set_postfix({
"lr": lr, "lr": lr,
"NLL": running_nll / batch_count, "NLL": avg_train_nll,
"Reg": running_reg / batch_count, "Reg": avg_train_reg,
}) })
loss.backward() loss.backward()
if self.cfg.grad_clip > 0: if self.cfg.grad_clip > 0:
@@ -351,8 +356,8 @@ class Trainer:
print("No valid batches in this epoch, skipping validation.") print("No valid batches in this epoch, skipping validation.")
continue continue
train_nll = running_nll / batch_count train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
train_reg = running_reg / batch_count train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
self.model.eval() self.model.eval()
total_val_pairs = 0 total_val_pairs = 0
@@ -422,7 +427,9 @@ class Trainer:
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:") tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
tqdm.write(f" Train NLL: {train_nll:.4f}") tqdm.write(f" Train NLL: {train_nll:.4f}")
tqdm.write(f" Train Reg: {train_reg:.4f}")
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC") tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
tqdm.write(f" Val Reg: {val_reg:.4f}")
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f: with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
json.dump(history, f, indent=4) json.dump(history, f, indent=4)