diff --git a/models.py b/models.py index 203e84f..2bd26bb 100644 --- a/models.py +++ b/models.py @@ -420,8 +420,12 @@ class CombinedLoss(nn.Module): per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') loss_ce = per_element_ce[mask].mean() - intensity = torch.sum(torch.exp(logits), dim=2) - per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) - loss_survival = per_element_survival[mask].mean() + # Survival loss based on exponential log-likelihood + t_min = 0.1 + lse = torch.logsumexp(logits, dim=-1) + lse = -torch.log(torch.exp(-lse) + t_min) + ldt = -torch.log(t + t_min) + loss_dt = -(lse - torch.exp(lse - ldt)) + loss_survival = loss_dt[mask].mean() return loss_ce, loss_survival