Commit d9c6ba40 authored by Phil Wang's avatar Phil Wang
Browse files

trims output on right for causal convtranspose1d, thanks to @NPN in...

trims output on right for causal convtranspose1d, thanks to @NPN in https://github.com/lucidrains/audiolm-pytorch/issues/8 !
parent 4acee04e
Loading
Loading
Loading
Loading
+17 −1
Original line number Diff line number Diff line
@@ -88,6 +88,22 @@ class CausalConv1d(nn.Module):
        x = F.pad(x, (self.causal_padding, 0))
        return self.conv(x)

class CausalConvTranspose1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs):
        super().__init__()
        self.upsample_factor = stride
        self.padding = kernel_size - 1
        self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs)

    def forward(self, x):
        n = x.shape[-1]

        x = F.pad(x, (self.padding, 0))
        out = self.conv(x)
        out = out[..., :(n * self.upsample_factor)]

        return out

def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7):
    return Residual(nn.Sequential(
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation),
@@ -110,7 +126,7 @@ def DecoderBlock(chan_in, chan_out, stride):
    output_padding = 0 if even_stride else 1

    return nn.Sequential(
        nn.ConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride, padding = padding, output_padding = output_padding),
        CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
        ResidualUnit(chan_out, chan_out, 1),
        ResidualUnit(chan_out, chan_out, 3),
        ResidualUnit(chan_out, chan_out, 9),