Commit 6a81c3a3 authored by Phil Wang
Browse files

customize depth of local attention transformer blocks in soundstream

parent 513da28c
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -349,7 +349,8 @@ class SoundStream(nn.Module):
        mhesa_dim_head = 32,
        attn_window_size = 128,
        attn_dim_head = 64,
-        attn_heads = 8
+        attn_heads = 8,
+        attn_depth = 1,
    ):
        super().__init__()
        self.target_sample_hz = target_sample_hz # for resampling on the fly
@@ -386,7 +387,7 @@ class SoundStream(nn.Module):
            causal = True
        )

-        self.encoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None
+        self.encoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None

        self.rq = ResidualVQ(
            dim = codebook_dim,
@@ -400,7 +401,7 @@ class SoundStream(nn.Module):
            quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        )

-        self.decoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None
+        self.decoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None

        decoder_blocks = []

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.9.7',
+  version = '0.10.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',