refactor: Update survival loss calculation in CombinedLoss
This commit is contained in:
10
models.py
10
models.py
@@ -420,8 +420,12 @@ class CombinedLoss(nn.Module):
|
||||
per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
|
||||
loss_ce = per_element_ce[mask].mean()
|
||||
|
||||
intensity = torch.sum(torch.exp(logits), dim=2)
|
||||
per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)
|
||||
loss_survival = per_element_survival[mask].mean()
|
||||
# Survival loss based on exponential log-likelihood
|
||||
t_min = 0.1
|
||||
lse = torch.logsumexp(logits, dim=-1)
|
||||
lse = -torch.log(torch.exp(-lse) + t_min)
|
||||
ldt = -torch.log(t + t_min)
|
||||
loss_dt = -(lse - torch.exp(lse - ldt))
|
||||
loss_survival = loss_dt[mask].mean()
|
||||
|
||||
return loss_ce, loss_survival
|
||||
|
Reference in New Issue
Block a user