Commit 9d832142 authored by Phil Wang's avatar Phil Wang
Browse files

go with normal convtranspose1d, causal may be not possible?

parent c3ed8db5
Loading
Loading
Loading
Loading
+1 −19
Original line number Diff line number Diff line
@@ -35,24 +35,6 @@ class CausalConv1d(nn.Module):
        x = F.pad(x, (self.causal_padding, 0))
        return self.conv(x)

class CausalConvTranspose1d(nn.Module):
    """ unsure if this module is correct """

    def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs):
        super().__init__()
        self.upsample_factor = stride
        self.padding = kernel_size - 1
        self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs)

    def forward(self, x):
        n = x.shape[-1]

        x = F.pad(x, (self.padding, 0))
        out = self.conv(x)
        out = out[..., :(n * self.upsample_factor)]

        return out

def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7):
    return Residual(nn.Sequential(
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation),
@@ -73,7 +55,7 @@ def DecoderBlock(chan_in, chan_out, stride):
    output_padding = 0 if even_stride else 1

    return nn.Sequential(
        CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
        nn.ConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride, padding = padding, output_padding = output_padding),
        ResidualUnit(chan_out, chan_out, 1),
        ResidualUnit(chan_out, chan_out, 3),
        ResidualUnit(chan_out, chan_out, 9),