fix(conv): ensure tensors are contiguous before Conv1d (post-pad and post-transpose) to avoid cuDNN ptrDesc->finalize errors
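transpose() (and some padding/slicing paths) can yield tensors whose storage is not laid out densely, and on the affected builds cuDNN's descriptor setup rejected them (the ptrDesc->finalize failure named in the title); calling .contiguous() right before the convolution copies the data into a dense buffer only when needed. Below is a minimal, self-contained sketch of the CausalConv1d path this commit touches; the constructor signature, channel count, and the driver lines at the bottom are illustrative assumptions, since the diff only shows forward():

    # Sketch only: __init__ arguments and tensor sizes are assumed, not taken from the repo.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CausalConv1d(nn.Module):
        def __init__(self, channels, kernel_size):
            super().__init__()
            self.pad = kernel_size - 1                     # left padding so position t never sees the future
            self.conv = nn.Conv1d(channels, channels, kernel_size)

        def forward(self, x):                              # x: (B, C, L)
            x = F.pad(x, (self.pad, 0))                    # pad only on the left to ensure causality
            x = x.contiguous()                             # hand cuDNN a dense buffer, not a strided view
            return self.conv(x)

    x = torch.randn(2, 50, 16).transpose(1, 2)             # (B, C, L) view over (B, L, C) storage
    print(x.is_contiguous())                                # False: transpose() returns a strided view
    y = CausalConv1d(16, kernel_size=3)(x)                  # forward() runs; causal padding keeps L unchanged
    print(y.shape)                                          # torch.Size([2, 16, 50])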
@@ -17,6 +17,7 @@ class CausalConv1d(nn.Module):
         )

     def forward(self, x):  # x: (B, C, L)
         x = F.pad(x, (self.pad, 0))  # pad only on the left to ensure causality
+        x = x.contiguous()
         return self.conv(x)

 class DepthwiseSeparableCausalConvBlock(nn.Module):
@@ -29,10 +30,10 @@ class DepthwiseSeparableCausalConvBlock(nn.Module):
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):  # x: (B, L, D)
-        y = x.transpose(1, 2)  # (B, D, L)
+        y = x.transpose(1, 2).contiguous()  # (B, D, L)
         y = self.dw(y)  # (B, D, L)
-        y = self.pw(y)  # (B, D, L)
-        y = y.transpose(1, 2)  # (B, L, D)
+        y = self.pw(y.contiguous())  # (B, D, L)
+        y = y.transpose(1, 2).contiguous()  # (B, L, D)
         y = self.act(y)
         y = self.dropout(y)
         return self.ln(x + y)  # residual connection + layer norm (LN)
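As a quick sanity check of the layout assumption behind the added .contiguous() calls (tensor sizes below are arbitrary, and the block itself is not instantiated because its constructor is not shown in this diff):

    import torch

    x = torch.randn(4, 128, 64)               # (B, L, D), the layout forward() expects
    y = x.transpose(1, 2)                      # (B, D, L) strided view, as on the removed lines
    assert not y.is_contiguous()               # the non-contiguous view implicated in the cuDNN failure
    assert y.contiguous().is_contiguous()      # the dense copy that the depthwise conv now receives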