From 880fd53a4b3a077e86d94bd1bbf6472bbfd08869 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 9 Jan 2026 13:48:36 +0800 Subject: [PATCH] Enhance Trainer class: add logging for WeibullNLLLoss parameters during training and validation --- train.py | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/train.py b/train.py index c407c45..f584a05 100644 --- a/train.py +++ b/train.py @@ -308,6 +308,8 @@ class Trainer: total_train_pairs = 0 total_train_nll = 0.0 total_train_reg = 0.0 + total_train_log_scale_sq = 0.0 + total_train_log_shape_sq = 0.0 pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100) batch_count = 0 @@ -341,6 +343,17 @@ class Trainer: b_prev=b_prev, t_prev=t_prev, ) + + if isinstance(self.criterion, WeibullNLLLoss): + eps = float(self.criterion.eps) + shapes = torch.nn.functional.softplus(logits[..., 0]) + eps + scales = torch.nn.functional.softplus(logits[..., 1]) + eps + log_scale_sq = (torch.log(scales + eps) ** 2).mean() + log_shape_sq = (torch.log(shapes + eps) ** 2).mean() + else: + log_scale_sq = None + log_shape_sq = None + target_event = event_seq[b_next, t_next] - 2 nll_vec, reg = self.criterion( logits, @@ -354,6 +367,10 @@ class Trainer: total_train_pairs += num_pairs total_train_nll += nll_vec.sum().item() total_train_reg += reg.item() * num_pairs + if log_scale_sq is not None: + total_train_log_scale_sq += log_scale_sq.item() * num_pairs + if log_shape_sq is not None: + total_train_log_shape_sq += log_shape_sq.item() * num_pairs avg_train_nll = total_train_nll / total_train_pairs avg_train_reg = total_train_reg / total_train_pairs pbar.set_postfix({ @@ -374,11 +391,23 @@ class Trainer: train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0 train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0 + train_log_scale_sq = ( + total_train_log_scale_sq / total_train_pairs + if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) + else None + ) + train_log_shape_sq = ( + total_train_log_shape_sq / total_train_pairs + if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) + else None + ) self.model.eval() total_val_pairs = 0 total_val_nll = 0.0 total_val_reg = 0.0 + total_val_log_scale_sq = 0.0 + total_val_log_shape_sq = 0.0 with torch.no_grad(): val_pbar = tqdm(self.val_loader, desc="Validation") for batch in val_pbar: @@ -408,6 +437,19 @@ class Trainer: b_prev=b_prev, t_prev=t_prev ) + + if isinstance(self.criterion, WeibullNLLLoss): + eps = float(self.criterion.eps) + shapes = torch.nn.functional.softplus( + logits[..., 0]) + eps + scales = torch.nn.functional.softplus( + logits[..., 1]) + eps + log_scale_sq = (torch.log(scales + eps) ** 2).mean() + log_shape_sq = (torch.log(shapes + eps) ** 2).mean() + else: + log_scale_sq = None + log_shape_sq = None + target_events = event_seq[b_next, t_next] - 2 nll, reg = self.criterion( logits, @@ -418,6 +460,10 @@ class Trainer: batch_nll_sum = nll.sum().item() total_val_nll += batch_nll_sum total_val_reg += reg.item() * num_pairs + if log_scale_sq is not None: + total_val_log_scale_sq += log_scale_sq.item() * num_pairs + if log_shape_sq is not None: + total_val_log_shape_sq += log_shape_sq.item() * num_pairs total_val_pairs += num_pairs current_val_avg_nll = total_val_nll / \ @@ -432,6 +478,16 @@ class Trainer: val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0 val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0 + val_log_scale_sq = ( + total_val_log_scale_sq / total_val_pairs + if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) + else None + ) + val_log_shape_sq = ( + total_val_log_shape_sq / total_val_pairs + if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss) + else None + ) history.append({ "epoch": epoch, @@ -440,6 +496,10 @@ class Trainer: "val_nll": val_nll, "val_reg": val_reg, "delta_scale": delta_scale, + "train_log_scale_sq": train_log_scale_sq, + "train_log_shape_sq": train_log_shape_sq, + "val_log_scale_sq": val_log_scale_sq, + "val_log_shape_sq": val_log_shape_sq, }) tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:") @@ -449,6 +509,16 @@ class Trainer: tqdm.write(f" Val Reg: {val_reg:.4f}") if delta_scale is not None: tqdm.write(f" Delta scale: {delta_scale:.6g}") + if train_log_scale_sq is not None and train_log_shape_sq is not None: + tqdm.write( + f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}") + tqdm.write( + f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}") + if val_log_scale_sq is not None and val_log_shape_sq is not None: + tqdm.write( + f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}") + tqdm.write( + f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}") with open(os.path.join(self.out_dir, "training_history.json"), "w") as f: json.dump(history, f, indent=4)