fix(attn): convert boolean attention mask to additive float (-1e9) to avoid cudnn ptrDesc->finalize errors on some GPUs
@@ -100,8 +100,10 @@ class Block(nn.Module):
     def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
         normed_x = self.ln_1(x)
 
-        attn_mask = ~custom_mask
-        attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0)
+        # Build an additive attention mask to avoid backend issues with boolean masks on some GPUs
+        # custom_mask: True means allowed, False means masked. We convert to 0 for allowed and -large for masked.
+        mask_bool = (~custom_mask).repeat_interleave(self.n_head, dim=0)  # True where we want to mask
+        attn_mask = mask_bool.to(dtype=normed_x.dtype) * (-1e9)
 
         attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
         x = x + self.resid_dropout(attn_output)
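For reference, a minimal, self-contained sketch (not part of the commit) of the boolean-to-additive mask conversion, assuming a batch_first=True nn.MultiheadAttention with n_head heads and a per-batch boolean mask where True marks positions that may be attended to; the sizes and names below are illustrative.

import torch
import torch.nn as nn

batch, seq_len, d_model, n_head = 2, 5, 16, 4
mha = nn.MultiheadAttention(d_model, n_head, batch_first=True)
x = torch.randn(batch, seq_len, d_model)

# Boolean mask per batch element: True = may attend, False = masked out.
custom_mask = torch.rand(batch, seq_len, seq_len) > 0.2

# 3D mask expected by MultiheadAttention: (batch * n_head, L, S),
# ordered batch-major, which repeat_interleave along dim=0 produces.
mask_bool = (~custom_mask).repeat_interleave(n_head, dim=0)  # True where masked

# Additive float mask: 0.0 where allowed, -1e9 where masked, so masked
# scores vanish after softmax without handing a bool tensor to the backend.
attn_mask = mask_bool.to(dtype=x.dtype) * (-1e9)

out, _ = mha(x, x, x, attn_mask=attn_mask, need_weights=False)
print(out.shape)  # torch.Size([2, 5, 16])

Using a large negative constant such as -1e9 rather than float('-inf') also keeps softmax finite if a query row happens to be fully masked.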