Loading audiolm_pytorch/soundstream.py +53 −14 Original line number Diff line number Diff line Loading @@ -75,6 +75,11 @@ def gradient_penalty(wave, output, weight = 10): gradients = rearrange(gradients, 'b ... -> b (...)') return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean() # better sequential def Sequential(*mods): return nn.Sequential(*filter(exists, mods)) # discriminators class MultiScaleDiscriminator(nn.Module): Loading Loading @@ -124,6 +129,34 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates # autoregressive squeeze excitation class SqueezeExcite(nn.Module): def __init__(self, dim, reduction_factor = 4, dim_minimum = 8): super().__init__() dim_inner = max(dim_minimum, dim // reduction_factor) self.net = nn.Sequential( nn.Conv1d(dim, dim_inner, 1), nn.SiLU(), nn.Conv1d(dim_inner, dim, 1), nn.Sigmoid() ) def forward(self, x): seq, device = x.shape[-2], x.device # cumulative mean - since it is autoregressive cum_sum = x.cumsum(dim = -2) denom = torch.arange(1, seq + 1, device = device).float() cum_mean = cum_sum / rearrange(denom, 'n -> n 1') # glu gate gate = self.net(cum_mean) return x * gate # complex stft discriminator class ModReLU(nn.Module): Loading Loading @@ -284,34 +317,39 @@ class CausalConvTranspose1d(nn.Module): return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False): return Residual(Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), nn.ELU(), CausalConv1d(chan_out, chan_out, 1), nn.ELU() nn.ELU(), SqueezeExcite(chan_out) if squeeze_excite else None )) def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)): def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): it = cycle(cycle_dilations) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) return nn.Sequential( ResidualUnit(chan_in, chan_in, next(it)), ResidualUnit(chan_in, chan_in, next(it)), ResidualUnit(chan_in, chan_in, next(it)), residual_unit(chan_in, chan_in, next(it)), residual_unit(chan_in, chan_in, next(it)), residual_unit(chan_in, chan_in, next(it)), CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride) ) def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)): def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): even_stride = (stride % 2 == 0) padding = (stride + (0 if even_stride else 1)) // 2 output_padding = 0 if even_stride else 1 residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) it = cycle(cycle_dilations) return nn.Sequential( CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride), ResidualUnit(chan_out, chan_out, next(it)), ResidualUnit(chan_out, chan_out, next(it)), ResidualUnit(chan_out, chan_out, next(it)), residual_unit(chan_out, chan_out, next(it)), residual_unit(chan_out, chan_out, next(it)), residual_unit(chan_out, chan_out, next(it)), ) class LocalTransformer(nn.Module): Loading Loading @@ -383,7 +421,8 @@ class SoundStream(nn.Module): attn_heads = 8, attn_depth = 1, attn_xpos_scale_base = None, attn_dynamic_pos_bias = False attn_dynamic_pos_bias = False, squeeze_excite = False ): super().__init__() Loading @@ -408,7 +447,7 @@ class SoundStream(nn.Module): encoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations)) encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite)) self.encoder = nn.Sequential( CausalConv1d(input_channels, channels, 7), Loading Loading @@ -450,7 +489,7 @@ class SoundStream(nn.Module): decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations)) decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite)) self.decoder = nn.Sequential( CausalConv1d(codebook_dim, layer_channels[-1], 7), Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.22.0' __version__ = '0.22.1' Loading
audiolm_pytorch/soundstream.py +53 −14 Original line number Diff line number Diff line Loading @@ -75,6 +75,11 @@ def gradient_penalty(wave, output, weight = 10): gradients = rearrange(gradients, 'b ... -> b (...)') return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean() # better sequential def Sequential(*mods): return nn.Sequential(*filter(exists, mods)) # discriminators class MultiScaleDiscriminator(nn.Module): Loading Loading @@ -124,6 +129,34 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates # autoregressive squeeze excitation class SqueezeExcite(nn.Module): def __init__(self, dim, reduction_factor = 4, dim_minimum = 8): super().__init__() dim_inner = max(dim_minimum, dim // reduction_factor) self.net = nn.Sequential( nn.Conv1d(dim, dim_inner, 1), nn.SiLU(), nn.Conv1d(dim_inner, dim, 1), nn.Sigmoid() ) def forward(self, x): seq, device = x.shape[-2], x.device # cumulative mean - since it is autoregressive cum_sum = x.cumsum(dim = -2) denom = torch.arange(1, seq + 1, device = device).float() cum_mean = cum_sum / rearrange(denom, 'n -> n 1') # glu gate gate = self.net(cum_mean) return x * gate # complex stft discriminator class ModReLU(nn.Module): Loading Loading @@ -284,34 +317,39 @@ class CausalConvTranspose1d(nn.Module): return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False): return Residual(Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), nn.ELU(), CausalConv1d(chan_out, chan_out, 1), nn.ELU() nn.ELU(), SqueezeExcite(chan_out) if squeeze_excite else None )) def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)): def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): it = cycle(cycle_dilations) residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) return nn.Sequential( ResidualUnit(chan_in, chan_in, next(it)), ResidualUnit(chan_in, chan_in, next(it)), ResidualUnit(chan_in, chan_in, next(it)), residual_unit(chan_in, chan_in, next(it)), residual_unit(chan_in, chan_in, next(it)), residual_unit(chan_in, chan_in, next(it)), CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride) ) def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)): def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False): even_stride = (stride % 2 == 0) padding = (stride + (0 if even_stride else 1)) // 2 output_padding = 0 if even_stride else 1 residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite) it = cycle(cycle_dilations) return nn.Sequential( CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride), ResidualUnit(chan_out, chan_out, next(it)), ResidualUnit(chan_out, chan_out, next(it)), ResidualUnit(chan_out, chan_out, next(it)), residual_unit(chan_out, chan_out, next(it)), residual_unit(chan_out, chan_out, next(it)), residual_unit(chan_out, chan_out, next(it)), ) class LocalTransformer(nn.Module): Loading Loading @@ -383,7 +421,8 @@ class SoundStream(nn.Module): attn_heads = 8, attn_depth = 1, attn_xpos_scale_base = None, attn_dynamic_pos_bias = False attn_dynamic_pos_bias = False, squeeze_excite = False ): super().__init__() Loading @@ -408,7 +447,7 @@ class SoundStream(nn.Module): encoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations)) encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite)) self.encoder = nn.Sequential( CausalConv1d(input_channels, channels, 7), Loading Loading @@ -450,7 +489,7 @@ class SoundStream(nn.Module): decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations)) decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite)) self.decoder = nn.Sequential( CausalConv1d(codebook_dim, layer_channels[-1], 7), Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.22.0' __version__ = '0.22.1'