diff --git a/evaluate_models.py b/evaluate_models.py index ac64607..810e77f 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -342,8 +342,9 @@ def cifs_from_exponential_logits( return cif survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H) - survival = torch.where(total.squeeze(1) > 0, survival, - torch.ones_like(survival)) + # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors. + nonzero = (total.squeeze(1) > 0).unsqueeze(1) + survival = torch.where(nonzero, survival, torch.ones_like(survival)) return cif, survival