diff --git a/models.py b/models.py index 2555d48..a389725 100644 --- a/models.py +++ b/models.py @@ -17,6 +17,7 @@ class CausalConv1d(nn.Module): ) def forward(self, x): # x: (B, C, L) x = F.pad(x, (self.pad, 0)) # pad only on the left to ensure causality + x = x.contiguous() return self.conv(x) class DepthwiseSeparableCausalConvBlock(nn.Module): @@ -29,10 +30,10 @@ class DepthwiseSeparableCausalConvBlock(nn.Module): self.dropout = nn.Dropout(dropout) def forward(self, x): # x: (B, L, D) - y = x.transpose(1, 2) # (B, D, L) - y = self.dw(y) # (B, D, L) - y = self.pw(y) # (B, D, L) - y = y.transpose(1, 2) # (B, L, D) + y = x.transpose(1, 2).contiguous() # (B, D, L) + y = self.dw(y) # (B, D, L) + y = self.pw(y.contiguous()) # (B, D, L) + y = y.transpose(1, 2).contiguous() # (B, L, D) y = self.act(y) y = self.dropout(y) return self.ln(x + y) # residual connection + layer norm (LN)