Loading audiolm_pytorch/soundstream.py +17 −15 Original line number Diff line number Diff line Loading @@ -305,17 +305,18 @@ class Residual(nn.Module): return self.fn(x, **kwargs) + x class CausalConv1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size, **kwargs): def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs): super().__init__() kernel_size = kernel_size dilation = kwargs.get('dilation', 1) stride = kwargs.get('stride', 1) self.pad_mode = pad_mode self.causal_padding = dilation * (kernel_size - 1) + (1 - stride) self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs) def forward(self, x): x = F.pad(x, (self.causal_padding, 0), mode = 'reflect') x = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode) return self.conv(x) class CausalConvTranspose1d(nn.Module): Loading @@ -333,18 +334,18 @@ class CausalConvTranspose1d(nn.Module): return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False): def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'): return Residual(Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode), nn.ELU(), CausalConv1d(chan_out, chan_out, 1), CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode), nn.ELU(), SqueezeExcite(chan_out) if squeeze_excite else None )) def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'): it = cycle(cycle_dilations) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode) return nn.Sequential( residual_unit(chan_in, chan_in, next(it)), Loading @@ -353,12 +354,12 @@ def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride) ) def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'): even_stride = (stride % 2 == 0) padding = (stride + (0 if even_stride else 1)) // 2 output_padding = 0 if even_stride else 1 residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode) it = cycle(cycle_dilations) return nn.Sequential( Loading Loading @@ -450,6 +451,7 @@ class SoundStream(nn.Module): attn_dynamic_pos_bias = False, squeeze_excite = False, complex_stft_discr_logits_abs = True, pad_mode = 'reflect', stft_discriminator: Optional[nn.Module] = None # can pass in own stft discriminator ): super().__init__() Loading @@ -475,12 +477,12 @@ class SoundStream(nn.Module): encoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite)) encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite, pad_mode)) self.encoder = nn.Sequential( CausalConv1d(input_channels, channels, 7), CausalConv1d(input_channels, channels, 7, pad_mode = pad_mode), *encoder_blocks, CausalConv1d(layer_channels[-1], codebook_dim, 3) CausalConv1d(layer_channels[-1], codebook_dim, 3, pad_mode = pad_mode) ) attn_kwargs = dict( Loading Loading @@ -526,12 +528,12 @@ class SoundStream(nn.Module): decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite)) decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite, pad_mode)) self.decoder = nn.Sequential( CausalConv1d(codebook_dim, layer_channels[-1], 7), CausalConv1d(codebook_dim, layer_channels[-1], 7, pad_mode = pad_mode), *decoder_blocks, CausalConv1d(channels, input_channels, 7) CausalConv1d(channels, input_channels, 7, pad_mode = pad_mode) ) # discriminators Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.30.0' __version__ = '0.30.1' Loading
audiolm_pytorch/soundstream.py +17 −15 Original line number Diff line number Diff line Loading @@ -305,17 +305,18 @@ class Residual(nn.Module): return self.fn(x, **kwargs) + x class CausalConv1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size, **kwargs): def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs): super().__init__() kernel_size = kernel_size dilation = kwargs.get('dilation', 1) stride = kwargs.get('stride', 1) self.pad_mode = pad_mode self.causal_padding = dilation * (kernel_size - 1) + (1 - stride) self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs) def forward(self, x): x = F.pad(x, (self.causal_padding, 0), mode = 'reflect') x = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode) return self.conv(x) class CausalConvTranspose1d(nn.Module): Loading @@ -333,18 +334,18 @@ class CausalConvTranspose1d(nn.Module): return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False): def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'): return Residual(Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode), nn.ELU(), CausalConv1d(chan_out, chan_out, 1), CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode), nn.ELU(), SqueezeExcite(chan_out) if squeeze_excite else None )) def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'): it = cycle(cycle_dilations) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode) return nn.Sequential( residual_unit(chan_in, chan_in, next(it)), Loading @@ -353,12 +354,12 @@ def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride) ) def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'): even_stride = (stride % 2 == 0) padding = (stride + (0 if even_stride else 1)) // 2 output_padding = 0 if even_stride else 1 residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode) it = cycle(cycle_dilations) return nn.Sequential( Loading Loading @@ -450,6 +451,7 @@ class SoundStream(nn.Module): attn_dynamic_pos_bias = False, squeeze_excite = False, complex_stft_discr_logits_abs = True, pad_mode = 'reflect', stft_discriminator: Optional[nn.Module] = None # can pass in own stft discriminator ): super().__init__() Loading @@ -475,12 +477,12 @@ class SoundStream(nn.Module): encoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite)) encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite, pad_mode)) self.encoder = nn.Sequential( CausalConv1d(input_channels, channels, 7), CausalConv1d(input_channels, channels, 7, pad_mode = pad_mode), *encoder_blocks, CausalConv1d(layer_channels[-1], codebook_dim, 3) CausalConv1d(layer_channels[-1], codebook_dim, 3, pad_mode = pad_mode) ) attn_kwargs = dict( Loading Loading @@ -526,12 +528,12 @@ class SoundStream(nn.Module): decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite)) decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite, pad_mode)) self.decoder = nn.Sequential( CausalConv1d(codebook_dim, layer_channels[-1], 7), CausalConv1d(codebook_dim, layer_channels[-1], 7, pad_mode = pad_mode), *decoder_blocks, CausalConv1d(channels, input_channels, 7) CausalConv1d(channels, input_channels, 7, pad_mode = pad_mode) ) # discriminators Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.30.0' __version__ = '0.30.1'