Loading audiolm_pytorch/soundstream.py +24 −1 Original line number Diff line number Diff line Loading @@ -9,6 +9,7 @@ import torch.nn.functional as F from einops import rearrange, reduce from vector_quantize_pytorch import ResidualVQ from local_attention import LocalMHA from audiolm_pytorch.utils import curtail_to_multiple Loading Loading @@ -254,7 +255,10 @@ class SoundStream(nn.Module): adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout_cutoff_index = 1, target_sample_hz = 24000 target_sample_hz = 24000, attn_window_size = 128, attn_dim_head = 64, attn_heads = 8 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading @@ -277,6 +281,17 @@ class SoundStream(nn.Module): CausalConv1d(layer_channels[-1], codebook_dim, 3) ) attn_kwargs = dict( dim = codebook_dim, dim_head = attn_dim_head, heads = attn_heads, window_size = attn_window_size, prenorm = True, causal = True ) self.encoder_attn = LocalMHA(**attn_kwargs) self.rq = ResidualVQ( dim = codebook_dim, num_quantizers = rq_num_quantizers, Loading @@ -288,6 +303,8 @@ class SoundStream(nn.Module): quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) self.decoder_attn = LocalMHA(**attn_kwargs) decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): Loading Loading @@ -315,6 +332,8 @@ class SoundStream(nn.Module): def decode_from_codebook_indices(self, quantized_indices): codes = self.rq.get_codes_from_indices(quantized_indices) x = reduce(codes, 'q ... -> ...', 'sum') x = self.decoder_attn(x) + x x = rearrange(x, 'b n c -> b c n') return self.decoder(x) Loading Loading @@ -354,7 +373,11 @@ class SoundStream(nn.Module): x = self.encoder(x) x = rearrange(x, 'b c n -> b n c') x = self.encoder_attn(x) + x x, indices, commit_loss = self.rq(x) x = self.decoder_attn(x) + x x = rearrange(x, 'b n c -> b c n') if return_encoded: Loading setup.py +1 −0 Original line number Diff line number Diff line Loading @@ -24,6 +24,7 @@ setup( 'ema-pytorch', 'fairseq', 'joblib', 'local-attention>=1.5.7', 'scikit-learn', 'sentencepiece', 'torch>=1.6', Loading Loading
audiolm_pytorch/soundstream.py +24 −1 Original line number Diff line number Diff line Loading @@ -9,6 +9,7 @@ import torch.nn.functional as F from einops import rearrange, reduce from vector_quantize_pytorch import ResidualVQ from local_attention import LocalMHA from audiolm_pytorch.utils import curtail_to_multiple Loading Loading @@ -254,7 +255,10 @@ class SoundStream(nn.Module): adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout_cutoff_index = 1, target_sample_hz = 24000 target_sample_hz = 24000, attn_window_size = 128, attn_dim_head = 64, attn_heads = 8 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading @@ -277,6 +281,17 @@ class SoundStream(nn.Module): CausalConv1d(layer_channels[-1], codebook_dim, 3) ) attn_kwargs = dict( dim = codebook_dim, dim_head = attn_dim_head, heads = attn_heads, window_size = attn_window_size, prenorm = True, causal = True ) self.encoder_attn = LocalMHA(**attn_kwargs) self.rq = ResidualVQ( dim = codebook_dim, num_quantizers = rq_num_quantizers, Loading @@ -288,6 +303,8 @@ class SoundStream(nn.Module): quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) self.decoder_attn = LocalMHA(**attn_kwargs) decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): Loading Loading @@ -315,6 +332,8 @@ class SoundStream(nn.Module): def decode_from_codebook_indices(self, quantized_indices): codes = self.rq.get_codes_from_indices(quantized_indices) x = reduce(codes, 'q ... -> ...', 'sum') x = self.decoder_attn(x) + x x = rearrange(x, 'b n c -> b c n') return self.decoder(x) Loading Loading @@ -354,7 +373,11 @@ class SoundStream(nn.Module): x = self.encoder(x) x = rearrange(x, 'b c n -> b n c') x = self.encoder_attn(x) + x x, indices, commit_loss = self.rq(x) x = self.decoder_attn(x) + x x = rearrange(x, 'b n c -> b c n') if return_encoded: Loading
setup.py +1 −0 Original line number Diff line number Diff line Loading @@ -24,6 +24,7 @@ setup( 'ema-pytorch', 'fairseq', 'joblib', 'local-attention>=1.5.7', 'scikit-learn', 'sentencepiece', 'torch>=1.6', Loading