From 7e57e5d3b18dd1360b3728d2a0ae9d7a9c7f673f Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sat, 18 Oct 2025 15:21:10 +0800 Subject: [PATCH] refactor: Update survival loss calculation in CombinedLoss --- models.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/models.py b/models.py index 203e84f..2bd26bb 100644 --- a/models.py +++ b/models.py @@ -420,8 +420,12 @@ class CombinedLoss(nn.Module): per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') loss_ce = per_element_ce[mask].mean() - intensity = torch.sum(torch.exp(logits), dim=2) - per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) - loss_survival = per_element_survival[mask].mean() + # Survival loss based on exponential log-likelihood + t_min = 0.1 + lse = torch.logsumexp(logits, dim=-1) + lse = -torch.log(torch.exp(-lse) + t_min) + ldt = -torch.log(t + t_min) + loss_dt = -(lse - torch.exp(lse - ldt)) + loss_survival = loss_dt[mask].mean() return loss_ce, loss_survival