Fix survival calculation in cifs_from_exponential_logits: broadcast mask for compatibility with tensor shapes

This commit is contained in:
2026-01-10 11:42:03 +08:00
parent f231a2e4e5
commit f795aa5604

View File

@@ -342,8 +342,9 @@ def cifs_from_exponential_logits(
return cif
survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H)
survival = torch.where(total.squeeze(1) > 0, survival,
torch.ones_like(survival))
# Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
nonzero = (total.squeeze(1) > 0).unsqueeze(1)
survival = torch.where(nonzero, survival, torch.ones_like(survival))
return cif, survival