Commit ed10de42 authored by Phil Wang
Browse files

add learned multi-headed exponential moving average as an option for soundstream

parent 455def25
Loading
Loading
Loading
Loading
+9 −1
Original line number Diff line number Diff line
@@ -389,3 +389,11 @@ sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_
    primaryClass = {cs.CV}
}
```

```bibtex
@inproceedings{Ma2022MegaMA,
    title   = {Mega: Moving Average Equipped Gated Attention},
    author  = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
    year    = {2022}
}
```
+34 −0
Original line number Diff line number Diff line
@@ -17,6 +17,7 @@ from vector_quantize_pytorch import ResidualVQ
from local_attention import LocalMHA
from local_attention.transformer import FeedForward

from mega_pytorch import MultiHeadedEMA
from audiolm_pytorch.utils import curtail_to_multiple

# helper functions
@@ -185,6 +186,26 @@ class STFTDiscriminator(nn.Module):

        return logits, intermediates

# learned EMA blocks

class MultiHeadEMABlock(nn.Module):
    """Pre-norm residual wrapper around ``MultiHeadedEMA`` for channel-first input.

    The forward pass moves the channel axis last, applies ``LayerNorm`` over
    it, runs the wrapped ``MultiHeadedEMA`` module, moves the channel axis
    back, and adds the untouched input as a residual.

    Args:
        dim: size of the channel axis; used as the ``LayerNorm`` normalized
            shape and forwarded to ``MultiHeadedEMA`` as ``dim``.
        **kwargs: forwarded verbatim to ``MultiHeadedEMA``.
    """

    def __init__(
        self,
        dim,
        **kwargs
    ):
        super().__init__()
        self.prenorm = nn.LayerNorm(dim)
        self.mhema = MultiHeadedEMA(dim = dim, **kwargs)

    def forward(self, x):
        """Apply the EMA branch to ``x`` and add the result back onto ``x``.

        Args:
            x: 3-D tensor laid out as ``(b, c, n)`` per the rearrange
                patterns below, with ``c`` equal to the ``dim`` given at
                construction so that the ``LayerNorm`` applies.

        Returns:
            Tensor with the same ``(b, c, n)`` layout as the input.
        """
        # A plain reference is enough for the skip connection: `rearrange`
        # and `LayerNorm` are both out-of-place, and `self.mhema` only ever
        # receives the LayerNorm output, so nothing below writes into the
        # caller's tensor. Cloning here would copy the full activation on
        # every call for no effect on the result.
        residual = x

        x = rearrange(x, 'b c n -> b n c')  # channels last for LayerNorm / EMA
        x = self.prenorm(x)
        x = self.mhema(x)
        x = rearrange(x, 'b n c -> b c n')  # restore channel-first layout
        return x + residual

# sound stream

class Residual(nn.Module):
@@ -293,6 +314,9 @@ class SoundStream(nn.Module):
        quantize_dropout_cutoff_index = 1,
        target_sample_hz = 24000,
        use_local_attn = True,
        use_mhesa = True,
        mhesa_heads = 4,
        mhesa_dim_head = 32,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8
@@ -312,6 +336,11 @@ class SoundStream(nn.Module):
        for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides):
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations))

            if not use_mhesa:
                continue

            encoder_blocks.append(MultiHeadEMABlock(chan_out, dim_head = mhesa_dim_head, heads = mhesa_heads))

        self.encoder = nn.Sequential(
            CausalConv1d(input_channels, channels, 7),
            *encoder_blocks,
@@ -348,6 +377,11 @@ class SoundStream(nn.Module):
        for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)):
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations))

            if not use_mhesa:
                continue

            decoder_blocks.append(MultiHeadEMABlock(chan_in, dim_head = mhesa_dim_head, heads = mhesa_heads))

        self.decoder = nn.Sequential(
            CausalConv1d(codebook_dim, layer_channels[-1], 7),
            *decoder_blocks,
+3 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.6.3',
  version = '0.7.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -25,13 +25,14 @@ setup(
    'fairseq',
    'joblib',
    'local-attention>=1.5.7',
    'Mega-pytorch',
    'scikit-learn',
    'sentencepiece',
    'torch>=1.6',
    'torchaudio',
    'transformers',
    'tqdm',
    'vector-quantize-pytorch>=0.10.14'
    'vector-quantize-pytorch>=0.10.15'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',