Enhance Trainer class: add delta_scale logging for improved training statistics
This commit is contained in:
14
train.py
14
train.py
@@ -293,6 +293,17 @@ class Trainer:
|
||||
best_val_score = float('inf')
|
||||
patience_counter = 0
|
||||
for epoch in range(1, self.cfg.max_epochs + 1):
|
||||
model_for_logging = self.model.module if hasattr(
|
||||
self.model, "module") else self.model
|
||||
delta_scale = None
|
||||
theta_proj = getattr(model_for_logging, "theta_proj", None)
|
||||
if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
|
||||
try:
|
||||
delta_scale = float(
|
||||
theta_proj.delta_scale.detach().cpu().item())
|
||||
except Exception:
|
||||
delta_scale = None
|
||||
|
||||
self.model.train()
|
||||
total_train_pairs = 0
|
||||
total_train_nll = 0.0
|
||||
@@ -428,6 +439,7 @@ class Trainer:
|
||||
"train_reg": train_reg,
|
||||
"val_nll": val_nll,
|
||||
"val_reg": val_reg,
|
||||
"delta_scale": delta_scale,
|
||||
})
|
||||
|
||||
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
|
||||
@@ -435,6 +447,8 @@ class Trainer:
|
||||
tqdm.write(f" Train Reg: {train_reg:.4f}")
|
||||
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
|
||||
tqdm.write(f" Val Reg: {val_reg:.4f}")
|
||||
if delta_scale is not None:
|
||||
tqdm.write(f" Delta scale: {delta_scale:.6g}")
|
||||
|
||||
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
|
||||
json.dump(history, f, indent=4)
|
||||
|
||||
Reference in New Issue
Block a user