fix(attn): convert boolean attention mask to additive float (-1e9) to avoid cudnn ptrDesc->finalize on some GPUs

This commit is contained in:
2025-10-22 17:51:12 +08:00
parent 3bef72f50b
commit dd58ced9b9

View File

@@ -100,9 +100,11 @@ class Block(nn.Module):
def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
normed_x = self.ln_1(x) normed_x = self.ln_1(x)
attn_mask = ~custom_mask # Build an additive attention mask to avoid backend issues with boolean masks on some GPUs
attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0) # custom_mask: True means allowed, False means masked. We convert to 0 for allowed and -large for masked.
mask_bool = (~custom_mask).repeat_interleave(self.n_head, dim=0) # True where we want to mask
attn_mask = mask_bool.to(dtype=normed_x.dtype) * (-1e9)
attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False) attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
x = x + self.resid_dropout(attn_output) x = x + self.resid_dropout(attn_output)
x = x + self.mlpf(self.ln_2(x)) x = x + self.mlpf(self.ln_2(x))