fix(conv): ensure tensors are contiguous before Conv1d (post-pad and post-transpose) to avoid cuDNN ptrDesc->finalize

This commit is contained in:
2025-10-22 17:54:04 +08:00
parent dd58ced9b9
commit 4d1fc63667

View File

@@ -17,6 +17,7 @@ class CausalConv1d(nn.Module):
)
def forward(self, x): # x: (B, C, L)
x = F.pad(x, (self.pad, 0)) # pad only on the left to ensure causality
x = x.contiguous()
return self.conv(x)
class DepthwiseSeparableCausalConvBlock(nn.Module):
@@ -29,10 +30,10 @@ class DepthwiseSeparableCausalConvBlock(nn.Module):
self.dropout = nn.Dropout(dropout)
def forward(self, x): # x: (B, L, D)
y = x.transpose(1, 2) # (B, D, L)
y = x.transpose(1, 2).contiguous() # (B, D, L)
y = self.dw(y) # (B, D, L)
y = self.pw(y) # (B, D, L)
y = y.transpose(1, 2) # (B, L, D)
y = self.pw(y.contiguous()) # (B, D, L)
y = y.transpose(1, 2).contiguous() # (B, L, D)
y = self.act(y)
y = self.dropout(y)
return self.ln(x + y) # residual connection + layer norm (LN)