Enhance Trainer class: add logging for WeibullNLLLoss parameters during training and validation

This commit is contained in:
2026-01-09 13:48:36 +08:00
parent 8723bf7600
commit 880fd53a4b

View File

@@ -308,6 +308,8 @@ class Trainer:
total_train_pairs = 0 total_train_pairs = 0
total_train_nll = 0.0 total_train_nll = 0.0
total_train_reg = 0.0 total_train_reg = 0.0
total_train_log_scale_sq = 0.0
total_train_log_shape_sq = 0.0
pbar = tqdm(self.train_loader, pbar = tqdm(self.train_loader,
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100) desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
batch_count = 0 batch_count = 0
@@ -341,6 +343,17 @@ class Trainer:
b_prev=b_prev, b_prev=b_prev,
t_prev=t_prev, t_prev=t_prev,
) )
if isinstance(self.criterion, WeibullNLLLoss):
eps = float(self.criterion.eps)
shapes = torch.nn.functional.softplus(logits[..., 0]) + eps
scales = torch.nn.functional.softplus(logits[..., 1]) + eps
log_scale_sq = (torch.log(scales + eps) ** 2).mean()
log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
else:
log_scale_sq = None
log_shape_sq = None
target_event = event_seq[b_next, t_next] - 2 target_event = event_seq[b_next, t_next] - 2
nll_vec, reg = self.criterion( nll_vec, reg = self.criterion(
logits, logits,
@@ -354,6 +367,10 @@ class Trainer:
total_train_pairs += num_pairs total_train_pairs += num_pairs
total_train_nll += nll_vec.sum().item() total_train_nll += nll_vec.sum().item()
total_train_reg += reg.item() * num_pairs total_train_reg += reg.item() * num_pairs
if log_scale_sq is not None:
total_train_log_scale_sq += log_scale_sq.item() * num_pairs
if log_shape_sq is not None:
total_train_log_shape_sq += log_shape_sq.item() * num_pairs
avg_train_nll = total_train_nll / total_train_pairs avg_train_nll = total_train_nll / total_train_pairs
avg_train_reg = total_train_reg / total_train_pairs avg_train_reg = total_train_reg / total_train_pairs
pbar.set_postfix({ pbar.set_postfix({
@@ -374,11 +391,23 @@ class Trainer:
train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0 train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0 train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
train_log_scale_sq = (
total_train_log_scale_sq / total_train_pairs
if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
train_log_shape_sq = (
total_train_log_shape_sq / total_train_pairs
if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
self.model.eval() self.model.eval()
total_val_pairs = 0 total_val_pairs = 0
total_val_nll = 0.0 total_val_nll = 0.0
total_val_reg = 0.0 total_val_reg = 0.0
total_val_log_scale_sq = 0.0
total_val_log_shape_sq = 0.0
with torch.no_grad(): with torch.no_grad():
val_pbar = tqdm(self.val_loader, desc="Validation") val_pbar = tqdm(self.val_loader, desc="Validation")
for batch in val_pbar: for batch in val_pbar:
@@ -408,6 +437,19 @@ class Trainer:
b_prev=b_prev, b_prev=b_prev,
t_prev=t_prev t_prev=t_prev
) )
if isinstance(self.criterion, WeibullNLLLoss):
eps = float(self.criterion.eps)
shapes = torch.nn.functional.softplus(
logits[..., 0]) + eps
scales = torch.nn.functional.softplus(
logits[..., 1]) + eps
log_scale_sq = (torch.log(scales + eps) ** 2).mean()
log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
else:
log_scale_sq = None
log_shape_sq = None
target_events = event_seq[b_next, t_next] - 2 target_events = event_seq[b_next, t_next] - 2
nll, reg = self.criterion( nll, reg = self.criterion(
logits, logits,
@@ -418,6 +460,10 @@ class Trainer:
batch_nll_sum = nll.sum().item() batch_nll_sum = nll.sum().item()
total_val_nll += batch_nll_sum total_val_nll += batch_nll_sum
total_val_reg += reg.item() * num_pairs total_val_reg += reg.item() * num_pairs
if log_scale_sq is not None:
total_val_log_scale_sq += log_scale_sq.item() * num_pairs
if log_shape_sq is not None:
total_val_log_shape_sq += log_shape_sq.item() * num_pairs
total_val_pairs += num_pairs total_val_pairs += num_pairs
current_val_avg_nll = total_val_nll / \ current_val_avg_nll = total_val_nll / \
@@ -432,6 +478,16 @@ class Trainer:
val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0 val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0 val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
val_log_scale_sq = (
total_val_log_scale_sq / total_val_pairs
if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
val_log_shape_sq = (
total_val_log_shape_sq / total_val_pairs
if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
else None
)
history.append({ history.append({
"epoch": epoch, "epoch": epoch,
@@ -440,6 +496,10 @@ class Trainer:
"val_nll": val_nll, "val_nll": val_nll,
"val_reg": val_reg, "val_reg": val_reg,
"delta_scale": delta_scale, "delta_scale": delta_scale,
"train_log_scale_sq": train_log_scale_sq,
"train_log_shape_sq": train_log_shape_sq,
"val_log_scale_sq": val_log_scale_sq,
"val_log_shape_sq": val_log_shape_sq,
}) })
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:") tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -449,6 +509,16 @@ class Trainer:
tqdm.write(f" Val Reg: {val_reg:.4f}") tqdm.write(f" Val Reg: {val_reg:.4f}")
if delta_scale is not None: if delta_scale is not None:
tqdm.write(f" Delta scale: {delta_scale:.6g}") tqdm.write(f" Delta scale: {delta_scale:.6g}")
if train_log_scale_sq is not None and train_log_shape_sq is not None:
tqdm.write(
f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}")
tqdm.write(
f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}")
if val_log_scale_sq is not None and val_log_shape_sq is not None:
tqdm.write(
f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}")
tqdm.write(
f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}")
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f: with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
json.dump(history, f, indent=4) json.dump(history, f, indent=4)