From dd58ced9b9404839895593cdf75199f626acc1b4 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Wed, 22 Oct 2025 17:51:12 +0800
Subject: [PATCH] fix(attn): convert boolean attention mask to additive float
 (-1e9) to avoid cudnn ptrDesc->finalize on some GPUs

---
 models.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/models.py b/models.py
index be27593..2555d48 100644
--- a/models.py
+++ b/models.py
@@ -100,9 +100,11 @@ class Block(nn.Module):
 
     def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
         normed_x = self.ln_1(x)
-        attn_mask = ~custom_mask
-        attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0)
-
+        # Build an additive attention mask to avoid backend issues with boolean masks on some GPUs
+        # custom_mask: True means allowed, False means masked. We convert to 0 for allowed and -large for masked.
+        mask_bool = (~custom_mask).repeat_interleave(self.n_head, dim=0)  # True where we want to mask
+        attn_mask = mask_bool.to(dtype=normed_x.dtype) * (-1e9)
+
         attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
         x = x + self.resid_dropout(attn_output)
         x = x + self.mlpf(self.ln_2(x))
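
Note (not part of the patch): a minimal standalone sketch of the same boolean-to-additive mask conversion, runnable outside the model. The tensor sizes, n_head=4, and batch_first=True are illustrative assumptions, not values taken from models.py.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only (not from models.py).
batch, n_head, seq_len, embed_dim = 2, 4, 5, 16
attn = nn.MultiheadAttention(embed_dim, n_head, batch_first=True)

x = torch.randn(batch, seq_len, embed_dim)
# Per-sample mask with the same convention as custom_mask: True = allowed.
custom_mask = torch.rand(batch, seq_len, seq_len) > 0.2

# The patch's conversion: masked positions become -1e9, allowed positions 0.0.
# repeat_interleave expands to one mask per head, the (batch * n_head, L, S)
# layout MultiheadAttention expects for a 3D attn_mask.
mask_bool = (~custom_mask).repeat_interleave(n_head, dim=0)
attn_mask = mask_bool.to(dtype=x.dtype) * (-1e9)

out, _ = attn(x, x, x, attn_mask=attn_mask, need_weights=False)
print(out.shape)  # torch.Size([2, 5, 16])

Since the float mask is simply added to the attention logits before softmax, -1e9 drives masked positions' weights to effectively zero without routing a boolean mask through the cudnn path that raised the ptrDesc->finalize error.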