Loading audiolm_pytorch/audiolm_pytorch.py +1 −19 Original line number Diff line number Diff line Loading @@ -35,24 +35,6 @@ class CausalConv1d(nn.Module): x = F.pad(x, (self.causal_padding, 0)) return self.conv(x) class CausalConvTranspose1d(nn.Module): """ unsure if this module is correct """ def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs): super().__init__() self.upsample_factor = stride self.padding = kernel_size - 1 self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs) def forward(self, x): n = x.shape[-1] x = F.pad(x, (self.padding, 0)) out = self.conv(x) out = out[..., :(n * self.upsample_factor)] return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), Loading @@ -73,7 +55,7 @@ def DecoderBlock(chan_in, chan_out, stride): output_padding = 0 if even_stride else 1 return nn.Sequential( CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride), nn.ConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride, padding = padding, output_padding = output_padding), ResidualUnit(chan_out, chan_out, 1), ResidualUnit(chan_out, chan_out, 3), ResidualUnit(chan_out, chan_out, 9), Loading Loading
audiolm_pytorch/audiolm_pytorch.py +1 −19 Original line number Diff line number Diff line Loading @@ -35,24 +35,6 @@ class CausalConv1d(nn.Module): x = F.pad(x, (self.causal_padding, 0)) return self.conv(x) class CausalConvTranspose1d(nn.Module): """ unsure if this module is correct """ def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs): super().__init__() self.upsample_factor = stride self.padding = kernel_size - 1 self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs) def forward(self, x): n = x.shape[-1] x = F.pad(x, (self.padding, 0)) out = self.conv(x) out = out[..., :(n * self.upsample_factor)] return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), Loading @@ -73,7 +55,7 @@ def DecoderBlock(chan_in, chan_out, stride): output_padding = 0 if even_stride else 1 return nn.Sequential( CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride), nn.ConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride, padding = padding, output_padding = output_padding), ResidualUnit(chan_out, chan_out, 1), ResidualUnit(chan_out, chan_out, 3), ResidualUnit(chan_out, chan_out, 9), Loading