Enhance Trainer class: add logging for WeibullNLLLoss parameters during training and validation
train.py (+70, -0)
@@ -308,6 +308,8 @@ class Trainer:
         total_train_pairs = 0
         total_train_nll = 0.0
         total_train_reg = 0.0
+        total_train_log_scale_sq = 0.0
+        total_train_log_shape_sq = 0.0
         pbar = tqdm(self.train_loader,
                     desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
         batch_count = 0
@@ -341,6 +343,17 @@ class Trainer:
                 b_prev=b_prev,
                 t_prev=t_prev,
             )
+
+            if isinstance(self.criterion, WeibullNLLLoss):
+                eps = float(self.criterion.eps)
+                shapes = torch.nn.functional.softplus(logits[..., 0]) + eps
+                scales = torch.nn.functional.softplus(logits[..., 1]) + eps
+                log_scale_sq = (torch.log(scales + eps) ** 2).mean()
+                log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
+            else:
+                log_scale_sq = None
+                log_shape_sq = None
+
             target_event = event_seq[b_next, t_next] - 2
             nll_vec, reg = self.criterion(
                 logits,
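
The quantity tracked here is the mean squared log of the softplus-derived Weibull parameters. Below is a minimal standalone sketch of that computation, assuming `logits` is a `(num_pairs, 2)` tensor whose last dimension holds the shape and scale channels as in the diff; the helper name, the default `eps`, and the random inputs are illustrative only and not part of the commit.

import torch
import torch.nn.functional as F


def weibull_log_sq_metrics(logits: torch.Tensor, eps: float = 1e-6):
    """Mean of log(param + eps)^2 for the softplus-derived Weibull parameters.

    The default eps is an assumption for this sketch; in the Trainer it comes
    from the criterion (self.criterion.eps).
    """
    shapes = F.softplus(logits[..., 0]) + eps
    scales = F.softplus(logits[..., 1]) + eps
    log_scale_sq = (torch.log(scales + eps) ** 2).mean()
    log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
    return log_scale_sq, log_shape_sq


# Example with random logits for 8 (batch, time) pairs.
logits = torch.randn(8, 2)
scale_sq, shape_sq = weibull_log_sq_metrics(logits)
print(f"log(scale+eps)^2 mean: {scale_sq.item():.6g}")
print(f"log(shape+eps)^2 mean: {shape_sq.item():.6g}")

The metric is zero when a parameter equals 1 and grows as it drifts away from 1 on a log scale, which makes it a cheap way to spot shape/scale drift over training.
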
@@ -354,6 +367,10 @@ class Trainer:
             total_train_pairs += num_pairs
             total_train_nll += nll_vec.sum().item()
             total_train_reg += reg.item() * num_pairs
+            if log_scale_sq is not None:
+                total_train_log_scale_sq += log_scale_sq.item() * num_pairs
+            if log_shape_sq is not None:
+                total_train_log_shape_sq += log_shape_sq.item() * num_pairs
             avg_train_nll = total_train_nll / total_train_pairs
             avg_train_reg = total_train_reg / total_train_pairs
             pbar.set_postfix({
@@ -374,11 +391,23 @@ class Trainer:

         train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
         train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0
+        train_log_scale_sq = (
+            total_train_log_scale_sq / total_train_pairs
+            if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
+            else None
+        )
+        train_log_shape_sq = (
+            total_train_log_shape_sq / total_train_pairs
+            if total_train_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
+            else None
+        )

         self.model.eval()
         total_val_pairs = 0
         total_val_nll = 0.0
         total_val_reg = 0.0
+        total_val_log_scale_sq = 0.0
+        total_val_log_shape_sq = 0.0
         with torch.no_grad():
             val_pbar = tqdm(self.val_loader, desc="Validation")
             for batch in val_pbar:
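
As the two hunks above show, the per-batch means are accumulated with a `num_pairs` weight and divided by the total pair count at the end of the epoch, so batches with more (batch, time) pairs contribute proportionally. A small sketch of that pattern, with made-up batch sizes and values:

# Weighted running average over batches (values are illustrative only).
batches = [(32, 0.41), (32, 0.39), (17, 0.44)]  # (num_pairs, per-batch mean)

total_pairs = 0
total_metric = 0.0
for num_pairs, batch_mean in batches:
    total_metric += batch_mean * num_pairs
    total_pairs += num_pairs

epoch_mean = total_metric / total_pairs if total_pairs > 0 else 0.0
print(f"epoch mean: {epoch_mean:.4f}")
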
@@ -408,6 +437,19 @@ class Trainer:
                     b_prev=b_prev,
                     t_prev=t_prev
                 )
+
+                if isinstance(self.criterion, WeibullNLLLoss):
+                    eps = float(self.criterion.eps)
+                    shapes = torch.nn.functional.softplus(
+                        logits[..., 0]) + eps
+                    scales = torch.nn.functional.softplus(
+                        logits[..., 1]) + eps
+                    log_scale_sq = (torch.log(scales + eps) ** 2).mean()
+                    log_shape_sq = (torch.log(shapes + eps) ** 2).mean()
+                else:
+                    log_scale_sq = None
+                    log_shape_sq = None
+
                 target_events = event_seq[b_next, t_next] - 2
                 nll, reg = self.criterion(
                     logits,
@@ -418,6 +460,10 @@ class Trainer:
                 batch_nll_sum = nll.sum().item()
                 total_val_nll += batch_nll_sum
                 total_val_reg += reg.item() * num_pairs
+                if log_scale_sq is not None:
+                    total_val_log_scale_sq += log_scale_sq.item() * num_pairs
+                if log_shape_sq is not None:
+                    total_val_log_shape_sq += log_shape_sq.item() * num_pairs
                 total_val_pairs += num_pairs

                 current_val_avg_nll = total_val_nll / \
@@ -432,6 +478,16 @@ class Trainer:

         val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
         val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0
+        val_log_scale_sq = (
+            total_val_log_scale_sq / total_val_pairs
+            if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
+            else None
+        )
+        val_log_shape_sq = (
+            total_val_log_shape_sq / total_val_pairs
+            if total_val_pairs > 0 and isinstance(self.criterion, WeibullNLLLoss)
+            else None
+        )

         history.append({
             "epoch": epoch,
@@ -440,6 +496,10 @@ class Trainer:
             "val_nll": val_nll,
             "val_reg": val_reg,
             "delta_scale": delta_scale,
+            "train_log_scale_sq": train_log_scale_sq,
+            "train_log_shape_sq": train_log_shape_sq,
+            "val_log_scale_sq": val_log_scale_sq,
+            "val_log_shape_sq": val_log_shape_sq,
         })

         tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
@@ -449,6 +509,16 @@ class Trainer:
         tqdm.write(f" Val Reg: {val_reg:.4f}")
         if delta_scale is not None:
             tqdm.write(f" Delta scale: {delta_scale:.6g}")
+        if train_log_scale_sq is not None and train_log_shape_sq is not None:
+            tqdm.write(
+                f" Train log(scale+eps)^2 mean: {train_log_scale_sq:.6g}")
+            tqdm.write(
+                f" Train log(shape+eps)^2 mean: {train_log_shape_sq:.6g}")
+        if val_log_scale_sq is not None and val_log_shape_sq is not None:
+            tqdm.write(
+                f" Val log(scale+eps)^2 mean: {val_log_scale_sq:.6g}")
+            tqdm.write(
+                f" Val log(shape+eps)^2 mean: {val_log_shape_sq:.6g}")

         with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
             json.dump(history, f, indent=4)
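
The new fields are written to training_history.json next to the existing ones and are null when the run's criterion is not a WeibullNLLLoss, so they can be inspected or plotted after training. A minimal sketch, using a hypothetical output path in place of the Trainer's actual out_dir:

import json

# Hypothetical run directory; substitute the out_dir used by the Trainer.
with open("runs/example/training_history.json") as f:
    history = json.load(f)

for entry in history:
    # None/null if the run used a non-Weibull criterion.
    print(entry["epoch"],
          entry.get("train_log_scale_sq"),
          entry.get("val_log_scale_sq"))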