diff --git a/models.py b/models.py
index be27593..2555d48 100644
--- a/models.py
+++ b/models.py
@@ -100,9 +100,11 @@ class Block(nn.Module):
 
     def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor:
         normed_x = self.ln_1(x)
-        attn_mask = ~custom_mask
-        attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0)
-
+        # Build an additive attention mask to avoid backend issues with boolean masks on some GPUs
+        # custom_mask: True means allowed, False means masked. We convert to 0 for allowed and -large for masked.
+        mask_bool = (~custom_mask).repeat_interleave(self.n_head, dim=0)  # True where we want to mask
+        attn_mask = mask_bool.to(dtype=normed_x.dtype) * (-1e9)
+
         attn_output, _ = self.attn(normed_x, normed_x, normed_x, attn_mask=attn_mask, need_weights=False)
         x = x + self.resid_dropout(attn_output)
         x = x + self.mlpf(self.ln_2(x))
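
Note (not part of the patch): the snippet below is a minimal standalone sketch of the boolean-to-additive mask conversion used above, shown with nn.MultiheadAttention. The sizes (batch, seq_len, embed_dim, n_head) and the example custom_mask are illustrative assumptions, not values from the repository; the conversion lines mirror the patch. The additive float mask (0 for allowed, -1e9 for masked) is what gets added to the attention logits, which is the workaround the patch comment describes for boolean-mask issues on some backends.

    import torch
    import torch.nn as nn

    # Illustrative sizes (assumptions, not taken from the patch)
    batch, seq_len, embed_dim, n_head = 2, 4, 8, 2

    attn = nn.MultiheadAttention(embed_dim, n_head, batch_first=True)
    x = torch.randn(batch, seq_len, embed_dim)

    # custom_mask: (batch, seq_len, seq_len), True = may attend, False = masked out
    custom_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)
    custom_mask[:, :, -1] = False  # e.g. forbid attending to the last position

    # Same conversion as in the patch: per-head boolean mask -> additive float mask
    mask_bool = (~custom_mask).repeat_interleave(n_head, dim=0)  # (batch * n_head, L, L), True = mask
    attn_mask = mask_bool.to(dtype=x.dtype) * (-1e9)             # 0 where allowed, -1e9 where masked

    out, _ = attn(x, x, x, attn_mask=attn_mask, need_weights=False)
    print(out.shape)  # torch.Size([2, 4, 8])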