Fix survival calculation in cifs_from_exponential_logits: broadcast mask for compatibility with tensor shapes
This commit is contained in:
@@ -342,8 +342,9 @@ def cifs_from_exponential_logits(
|
||||
return cif
|
||||
|
||||
survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H)
|
||||
survival = torch.where(total.squeeze(1) > 0, survival,
|
||||
torch.ones_like(survival))
|
||||
# Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
|
||||
nonzero = (total.squeeze(1) > 0).unsqueeze(1)
|
||||
survival = torch.where(nonzero, survival, torch.ones_like(survival))
|
||||
return cif, survival
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user