From 8723bf7600a410c9e4a1ba53df67f0d4630a81d5 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Fri, 9 Jan 2026 13:28:11 +0800
Subject: [PATCH] Enhance Trainer class: add delta_scale logging for improved training statistics

---
 train.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/train.py b/train.py
index 5d37deb..c407c45 100644
--- a/train.py
+++ b/train.py
@@ -293,6 +293,17 @@ class Trainer:
         best_val_score = float('inf')
         patience_counter = 0
         for epoch in range(1, self.cfg.max_epochs + 1):
+            model_for_logging = self.model.module if hasattr(
+                self.model, "module") else self.model
+            delta_scale = None
+            theta_proj = getattr(model_for_logging, "theta_proj", None)
+            if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
+                try:
+                    delta_scale = float(
+                        theta_proj.delta_scale.detach().cpu().item())
+                except Exception:
+                    delta_scale = None
+
             self.model.train()
             total_train_pairs = 0
             total_train_nll = 0.0
@@ -428,6 +439,7 @@
                 "train_reg": train_reg,
                 "val_nll": val_nll,
                 "val_reg": val_reg,
+                "delta_scale": delta_scale,
             })
 
             tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -435,6 +447,8 @@
             tqdm.write(f" Train Reg: {train_reg:.4f}")
             tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
             tqdm.write(f" Val Reg: {val_reg:.4f}")
+            if delta_scale is not None:
+                tqdm.write(f" Delta scale: {delta_scale:.6g}")
 
             with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
                 json.dump(history, f, indent=4)