Enhance Trainer class: add logging for WeibullNLLLoss parameters during training and validation
This commit is contained in:
70
train.py
70
train.py
@@ -308,6 +308,8 @@ class Trainer:
|
||||
total_train_pairs = 0
|
||||
total_train_nll = 0.0
|
||||
total_train_reg = 0.0
|
||||
total_train_log_scale_sq = 0.0
|
||||
total_train_log_shape_sq = 0.0
|
||||
pbar = tqdm(self.train_loader,
|
||||
desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
|
||||
batch_count = 0
|
||||
@@ -341,6 +343,17 @@ class Trainer:
|
||||
b_prev=b_prev,
|
||||
t_prev=t_prev,
|
||||
)
|
||||
|
||||
if isinstance(self.criterion, WeibullNLLLoss):
|
||||
eps = float(self.criterion.eps)
|
||||
shapes = torch.nn.functional.softplus(logits[..., 0]) + eps
|
||||
scales = torch.nn.functional.softplus(logits[..., 1]) + eps
|
||||
log_scale_sq = (torch.log(scales + eps) ** 2).mean()
|
||||
log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
|
||||
else:
|
||||
log_scale_sq = None
|
||||
log_shape_sq = None
|
||||
|
||||
target_event = event_seq[b_next, t_next] - 2
|
||||
nll_vec, reg = self.criterion(
|
||||
logits,
|
||||
@@ -354,6 +367,10 @@ class Trainer:
|
||||
total_train_pairs += num_pairs
|
||||
total_train_nll += nll_vec.sum().item()
|
||||
total_train_reg += reg.item() * num_pairs
|
||||
if log_scale_sq is not None:
|
||||
total_train_log_scale_sq += log_scale_sq.item() * num_pairs
|
||||
if log_shape_sq is not None:
|
||||
total_train_log_shape_sq += log_shape_sq.item() * num_pairs
|
||||
avg_train_nll = total_train_nll / total_train_pairs
|
||||
avg_train_reg = total_train_reg / total_train_pairs
|
||||
pbar.set_postfix({
|
||||
@@ -374,11 +391,23 @@ class Trainer:
|
||||
|
||||
train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
|
||||
train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
|
||||
train_log_scale_sq = (
|
||||
total_train_log_scale_sq / total_train_pairs
|
||||
if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
|
||||
else None
|
||||
)
|
||||
train_log_shape_sq = (
|
||||
total_train_log_shape_sq / total_train_pairs
|
||||
if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
|
||||
else None
|
||||
)
|
||||
|
||||
self.model.eval()
|
||||
total_val_pairs = 0
|
||||
total_val_nll = 0.0
|
||||
total_val_reg = 0.0
|
||||
total_val_log_scale_sq = 0.0
|
||||
total_val_log_shape_sq = 0.0
|
||||
with torch.no_grad():
|
||||
val_pbar = tqdm(self.val_loader, desc="Validation")
|
||||
for batch in val_pbar:
|
||||
@@ -408,6 +437,19 @@ class Trainer:
|
||||
b_prev=b_prev,
|
||||
t_prev=t_prev
|
||||
)
|
||||
|
||||
if isinstance(self.criterion, WeibullNLLLoss):
|
||||
eps = float(self.criterion.eps)
|
||||
shapes = torch.nn.functional.softplus(
|
||||
logits[..., 0]) + eps
|
||||
scales = torch.nn.functional.softplus(
|
||||
logits[..., 1]) + eps
|
||||
log_scale_sq = (torch.log(scales + eps) ** 2).mean()
|
||||
log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
|
||||
else:
|
||||
log_scale_sq = None
|
||||
log_shape_sq = None
|
||||
|
||||
target_events = event_seq[b_next, t_next] - 2
|
||||
nll, reg = self.criterion(
|
||||
logits,
|
||||
@@ -418,6 +460,10 @@ class Trainer:
|
||||
batch_nll_sum = nll.sum().item()
|
||||
total_val_nll += batch_nll_sum
|
||||
total_val_reg += reg.item() * num_pairs
|
||||
if log_scale_sq is not None:
|
||||
total_val_log_scale_sq += log_scale_sq.item() * num_pairs
|
||||
if log_shape_sq is not None:
|
||||
total_val_log_shape_sq += log_shape_sq.item() * num_pairs
|
||||
total_val_pairs += num_pairs
|
||||
|
||||
current_val_avg_nll = total_val_nll / \
|
||||
@@ -432,6 +478,16 @@ class Trainer:
|
||||
|
||||
val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
|
||||
val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
|
||||
val_log_scale_sq = (
|
||||
total_val_log_scale_sq / total_val_pairs
|
||||
if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
|
||||
else None
|
||||
)
|
||||
val_log_shape_sq = (
|
||||
total_val_log_shape_sq / total_val_pairs
|
||||
if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
|
||||
else None
|
||||
)
|
||||
|
||||
history.append({
|
||||
"epoch": epoch,
|
||||
@@ -440,6 +496,10 @@ class Trainer:
|
||||
"val_nll": val_nll,
|
||||
"val_reg": val_reg,
|
||||
"delta_scale": delta_scale,
|
||||
"train_log_scale_sq": train_log_scale_sq,
|
||||
"train_log_shape_sq": train_log_shape_sq,
|
||||
"val_log_scale_sq": val_log_scale_sq,
|
||||
"val_log_shape_sq": val_log_shape_sq,
|
||||
})
|
||||
|
||||
tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
|
||||
@@ -449,6 +509,16 @@ class Trainer:
|
||||
tqdm.write(f" Val Reg: {val_reg:.4f}")
|
||||
if delta_scale is not None:
|
||||
tqdm.write(f" Delta scale: {delta_scale:.6g}")
|
||||
if train_log_scale_sq is not None and train_log_shape_sq is not None:
|
||||
tqdm.write(
|
||||
f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}")
|
||||
tqdm.write(
|
||||
f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}")
|
||||
if val_log_scale_sq is not None and val_log_shape_sq is not None:
|
||||
tqdm.write(
|
||||
f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}")
|
||||
tqdm.write(
|
||||
f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}")
|
||||
|
||||
with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
|
||||
json.dump(history, f, indent=4)
|
||||
|
||||
Reference in New Issue
Block a user