Fix survival calculation in cifs_from_exponential_logits: broadcast mask for compatibility with tensor shapes

This commit is contained in:
2026-01-10 11:42:03 +08:00
parent f231a2e4e5
commit f795aa5604

View File

@@ -342,8 +342,9 @@ def cifs_from_exponential_logits(
return cif return cif
survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H) survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H)
survival = torch.where(total.squeeze(1) > 0, survival, # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
torch.ones_like(survival)) nonzero = (total.squeeze(1) > 0).unsqueeze(1)
survival = torch.where(nonzero, survival, torch.ones_like(survival))
return cif, survival return cif, survival