Fix survival calculation in cifs_from_exponential_logits: broadcast mask for compatibility with tensor shapes
This commit is contained in:
@@ -342,8 +342,9 @@ def cifs_from_exponential_logits(
|
|||||||
return cif
|
return cif
|
||||||
|
|
||||||
survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H)
|
survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H)
|
||||||
survival = torch.where(total.squeeze(1) > 0, survival,
|
# Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors.
|
||||||
torch.ones_like(survival))
|
nonzero = (total.squeeze(1) > 0).unsqueeze(1)
|
||||||
|
survival = torch.where(nonzero, survival, torch.ones_like(survival))
|
||||||
return cif, survival
|
return cif, survival
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user