Enhance Trainer class: add delta_scale logging for improved training statistics

This commit is contained in:
2026-01-09 13:28:11 +08:00
parent dc34d51864
commit 8723bf7600

View File

@@ -293,6 +293,17 @@ class Trainer:
best_val_score = float('inf')
patience_counter = 0
for epoch in range(1, self.cfg.max_epochs + 1):
model_for_logging = self.model.module if hasattr(
self.model, "module") else self.model
delta_scale = None
theta_proj = getattr(model_for_logging, "theta_proj", None)
if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
try:
delta_scale = float(
theta_proj.delta_scale.detach().cpu().item())
except Exception:
delta_scale = None
self.model.train()
total_train_pairs = 0
total_train_nll = 0.0
@@ -428,6 +439,7 @@ class Trainer:
"train_reg": train_reg,
"val_nll": val_nll,
"val_reg": val_reg,
"delta_scale": delta_scale,
})
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -435,6 +447,8 @@ class Trainer:
tqdm.write(f" Train Reg: {train_reg:.4f}")
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
tqdm.write(f" Val Reg: {val_reg:.4f}")
if delta_scale is not None:
tqdm.write(f" Delta scale: {delta_scale:.6g}")
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
json.dump(history, f, indent=4)