Loading README.md +9 −1 Original line number Diff line number Diff line Loading @@ -389,3 +389,11 @@ sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_ primaryClass = {cs.CV} } ``` ```bibtex @inproceedings{Ma2022MegaMA, title = {Mega: Moving Average Equipped Gated Attention}, author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer}, year = {2022} } ``` audiolm_pytorch/soundstream.py +34 −0 Original line number Diff line number Diff line Loading @@ -17,6 +17,7 @@ from vector_quantize_pytorch import ResidualVQ from local_attention import LocalMHA from local_attention.transformer import FeedForward from mega_pytorch import MultiHeadedEMA from audiolm_pytorch.utils import curtail_to_multiple # helper functions Loading Loading @@ -185,6 +186,26 @@ class STFTDiscriminator(nn.Module): return logits, intermediates # learned EMA blocks class MultiHeadEMABlock(nn.Module): def __init__( self, dim, **kwargs ): super().__init__() self.prenorm = nn.LayerNorm(dim) self.mhema = MultiHeadedEMA(dim = dim, **kwargs) def forward(self, x): residual = x.clone() x = rearrange(x, 'b c n -> b n c') x = self.prenorm(x) x = self.mhema(x) x = rearrange(x, 'b n c -> b c n') return x + residual # sound stream class Residual(nn.Module): Loading Loading @@ -293,6 +314,9 @@ class SoundStream(nn.Module): quantize_dropout_cutoff_index = 1, target_sample_hz = 24000, use_local_attn = True, use_mhesa = True, mhesa_heads = 4, mhesa_dim_head = 32, attn_window_size = 128, attn_dim_head = 64, attn_heads = 8 Loading @@ -312,6 +336,11 @@ class SoundStream(nn.Module): for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations)) if not use_mhesa: continue encoder_blocks.append(MultiHeadEMABlock(chan_out, dim_head = mhesa_dim_head, heads = mhesa_heads)) self.encoder = nn.Sequential( 
CausalConv1d(input_channels, channels, 7), *encoder_blocks, Loading Loading @@ -348,6 +377,11 @@ class SoundStream(nn.Module): for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations)) if not use_mhesa: continue decoder_blocks.append(MultiHeadEMABlock(chan_in, dim_head = mhesa_dim_head, heads = mhesa_heads)) self.decoder = nn.Sequential( CausalConv1d(codebook_dim, layer_channels[-1], 7), *decoder_blocks, Loading setup.py +3 −2 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.6.3', version = '0.7.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -25,13 +25,14 @@ setup( 'fairseq', 'joblib', 'local-attention>=1.5.7', 'Mega-pytorch', 'scikit-learn', 'sentencepiece', 'torch>=1.6', 'torchaudio', 'transformers', 'tqdm', 'vector-quantize-pytorch>=0.10.14' 'vector-quantize-pytorch>=0.10.15' ], classifiers=[ 'Development Status :: 4 - Beta', Loading Loading
README.md +9 −1 Original line number Diff line number Diff line Loading @@ -389,3 +389,11 @@ sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_ primaryClass = {cs.CV} } ``` ```bibtex @inproceedings{Ma2022MegaMA, title = {Mega: Moving Average Equipped Gated Attention}, author = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer}, year = {2022} } ```
# learned EMA blocks

class MultiHeadEMABlock(nn.Module):
    """Pre-norm residual block wrapping ``MultiHeadedEMA`` for channel-first input.

    The input is laid out as ``(b, c, n)``. It is rearranged to ``(b, n, c)``
    so that ``nn.LayerNorm(dim)`` and the EMA module act on the channel axis
    as the last dimension, then rearranged back and added to the untouched
    input as a residual.

    Args:
        dim: size of the ``c`` axis of the input; used for both the
            ``LayerNorm`` and the ``MultiHeadedEMA`` module.
        **kwargs: forwarded unchanged to ``MultiHeadedEMA``.
    """

    def __init__(
        self,
        dim,
        **kwargs
    ):
        super().__init__()
        self.prenorm = nn.LayerNorm(dim)
        self.mhema = MultiHeadedEMA(dim = dim, **kwargs)

    def forward(self, x):
        """Return ``x + mhema(prenorm(x))`` with ``x`` kept in ``(b, c, n)`` layout.

        NOTE(review): the residual add requires ``MultiHeadedEMA`` to return a
        tensor of the same ``(b, n, c)`` shape it was given - confirm against
        mega_pytorch.
        """
        # No copy is needed for the skip connection: nothing below writes to
        # `x` in place (rearrange / LayerNorm / the EMA module / `+` all
        # produce new tensors or views), so the original tensor can be reused
        # directly instead of cloning it on every forward pass.
        residual = x

        x = rearrange(x, 'b c n -> b n c')
        x = self.prenorm(x)
        x = self.mhema(x)
        x = rearrange(x, 'b n c -> b c n')

        return x + residual
CausalConv1d(codebook_dim, layer_channels[-1], 7), *decoder_blocks, Loading
setup.py +3 −2 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.6.3', version = '0.7.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -25,13 +25,14 @@ setup( 'fairseq', 'joblib', 'local-attention>=1.5.7', 'Mega-pytorch', 'scikit-learn', 'sentencepiece', 'torch>=1.6', 'torchaudio', 'transformers', 'tqdm', 'vector-quantize-pytorch>=0.10.14' 'vector-quantize-pytorch>=0.10.15' ], classifiers=[ 'Development Status :: 4 - Beta', Loading