From f795aa5604275825819a851541e0cc14105255cb Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Sat, 10 Jan 2026 11:42:03 +0800 Subject: [PATCH] Fix survival calculation in cifs_from_exponential_logits: broadcast mask for compatibility with tensor shapes --- evaluate_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/evaluate_models.py b/evaluate_models.py index ac64607..810e77f 100644 --- a/evaluate_models.py +++ b/evaluate_models.py @@ -342,8 +342,9 @@ def cifs_from_exponential_logits( return cif survival = torch.exp(-total_h * taus_t).squeeze(1) # (B,H) - survival = torch.where(total.squeeze(1) > 0, survival, - torch.ones_like(survival)) + # Broadcast mask (B,) -> (B,1) for torch.where with (B,H) tensors. + nonzero = (total.squeeze(1) > 0).unsqueeze(1) + survival = torch.where(nonzero, survival, torch.ones_like(survival)) return cif, survival