Commit c3ed8db5 authored by Phil Wang's avatar Phil Wang
Browse files

give causal convtranspose another shot

parent 3bdca366
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -40,12 +40,17 @@ class CausalConvTranspose1d(nn.Module):

    def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs):
        super().__init__()
        self.neg_padding = kernel_size // 2
        self.upsample_factor = stride
        self.padding = kernel_size - 1
        self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs)

    def forward(self, x):
        n = x.shape[-1]

        x = F.pad(x, (self.padding, 0))
        out = self.conv(x)
        out = out[..., :-self.neg_padding]
        out = out[..., :(n * self.upsample_factor)]

        return out

def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7):