Enhance Trainer class: add delta_scale logging for improved training statistics
This commit is contained in:
train.py (14 lines changed)
@@ -293,6 +293,17 @@ class Trainer:
|
|||||||
best_val_score = float('inf')
|
best_val_score = float('inf')
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
for epoch in range(1, self.cfg.max_epochs + 1):
|
for epoch in range(1, self.cfg.max_epochs + 1):
|
||||||
|
model_for_logging = self.model.module if hasattr(
|
||||||
|
self.model, "module") else self.model
|
||||||
|
delta_scale = None
|
||||||
|
theta_proj = getattr(model_for_logging, "theta_proj", None)
|
||||||
|
if theta_proj is not None and hasattr(theta_proj, "delta_scale"):
|
||||||
|
try:
|
||||||
|
delta_scale = float(
|
||||||
|
theta_proj.delta_scale.detach().cpu().item())
|
||||||
|
except Exception:
|
||||||
|
delta_scale = None
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
total_train_pairs = 0
|
total_train_pairs = 0
|
||||||
total_train_nll = 0.0
|
total_train_nll = 0.0
|
||||||
@@ -428,6 +439,7 @@ class Trainer:
|
|||||||
"train_reg": train_reg,
|
"train_reg": train_reg,
|
||||||
"val_nll": val_nll,
|
"val_nll": val_nll,
|
||||||
"val_reg": val_reg,
|
"val_reg": val_reg,
|
||||||
|
"delta_scale": delta_scale,
|
||||||
})
|
})
|
||||||
|
|
||||||
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
|
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
|
||||||
@@ -435,6 +447,8 @@ class Trainer:
|
|||||||
tqdm.write(f" Train Reg: {train_reg:.4f}")
|
tqdm.write(f" Train Reg: {train_reg:.4f}")
|
||||||
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
|
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
|
||||||
tqdm.write(f" Val Reg: {val_reg:.4f}")
|
tqdm.write(f" Val Reg: {val_reg:.4f}")
|
||||||
|
if delta_scale is not None:
|
||||||
|
tqdm.write(f" Delta scale: {delta_scale:.6g}")
|
||||||
|
|
||||||
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
|
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
|
||||||
json.dump(history, f, indent=4)
|
json.dump(history, f, indent=4)
|
||||||
|
|||||||
Reference in New Issue
Block a user