refactor: Update survival loss calculation in CombinedLoss

This commit is contained in:
2025-10-18 15:21:10 +08:00
parent 14865ac5b6
commit 7e57e5d3b1

View File

@@ -420,8 +420,12 @@ class CombinedLoss(nn.Module):
per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
loss_ce = per_element_ce[mask].mean()
intensity = torch.sum(torch.exp(logits), dim=2)
per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)
loss_survival = per_element_survival[mask].mean()
# Survival loss based on exponential log-likelihood
t_min = 0.1
lse = torch.logsumexp(logits, dim=-1)
lse = -torch.log(torch.exp(-lse) + t_min)
ldt = -torch.log(t + t_min)
loss_dt = -(lse - torch.exp(lse - ldt))
loss_survival = loss_dt[mask].mean()
return loss_ce, loss_survival