Commit a94e32f4 authored by Phil Wang's avatar Phil Wang
Browse files

remove multihead EMA based on feedback from @ilya16

parent 4daad9e7
Loading
Loading
Loading
Loading
+0 −14
Original line number Diff line number Diff line
@@ -21,7 +21,6 @@ from vector_quantize_pytorch import ResidualVQ
from local_attention import LocalMHA
from local_attention.transformer import FeedForward

from mega_pytorch import MultiHeadedEMA
from audiolm_pytorch.utils import curtail_to_multiple

from audiolm_pytorch.version import __version__
@@ -378,9 +377,6 @@ class SoundStream(nn.Module):
        quantize_dropout_cutoff_index = 1,
        target_sample_hz = 16000,
        use_local_attn = True,
        use_mhesa = True,
        mhesa_heads = 4,
        mhesa_dim_head = 32,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8,
@@ -411,11 +407,6 @@ class SoundStream(nn.Module):
        for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides):
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations))

            if not use_mhesa:
                continue

            encoder_blocks.append(MultiHeadEMABlock(chan_out, dim_head = mhesa_dim_head, heads = mhesa_heads))

        self.encoder = nn.Sequential(
            CausalConv1d(input_channels, channels, 7),
            *encoder_blocks,
@@ -452,11 +443,6 @@ class SoundStream(nn.Module):
        for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)):
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations))

            if not use_mhesa:
                continue

            decoder_blocks.append(MultiHeadEMABlock(chan_in, dim_head = mhesa_dim_head, heads = mhesa_heads))

        self.decoder = nn.Sequential(
            CausalConv1d(codebook_dim, layer_channels[-1], 7),
            *decoder_blocks,
+1 −1
Original line number Diff line number Diff line
__version__ = '0.17.1'
__version__ = '0.18.0'
+0 −1
Original line number Diff line number Diff line
@@ -27,7 +27,6 @@ setup(
    'joblib',
    'lion-pytorch',
    'local-attention>=1.6.0',
    'Mega-pytorch',
    'scikit-learn',
    'sentencepiece',
    'torch>=1.12',