refactor: Update survival loss calculation in CombinedLoss
This commit is contained in:
10
models.py
10
models.py
@@ -420,8 +420,12 @@ class CombinedLoss(nn.Module):
|
|||||||
per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
|
per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none')
|
||||||
loss_ce = per_element_ce[mask].mean()
|
loss_ce = per_element_ce[mask].mean()
|
||||||
|
|
||||||
intensity = torch.sum(torch.exp(logits), dim=2)
|
# Survival loss based on exponential log-likelihood
|
||||||
per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t)
|
t_min = 0.1
|
||||||
loss_survival = per_element_survival[mask].mean()
|
lse = torch.logsumexp(logits, dim=-1)
|
||||||
|
lse = -torch.log(torch.exp(-lse) + t_min)
|
||||||
|
ldt = -torch.log(t + t_min)
|
||||||
|
loss_dt = -(lse - torch.exp(lse - ldt))
|
||||||
|
loss_survival = loss_dt[mask].mean()
|
||||||
|
|
||||||
return loss_ce, loss_survival
|
return loss_ce, loss_survival
|
||||||
|
Reference in New Issue
Block a user