Commit b1f4c6fd authored by Phil Wang's avatar Phil Wang
Browse files

add local attention at the innermost layer of soundstream, flanking the...

add local attention at the innermost layer of soundstream, flanking the residual quantization layer on both encoder and decoder
parent 1a6cfc30
Loading
Loading
Loading
Loading
+24 −1
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ import torch.nn.functional as F
from einops import rearrange, reduce

from vector_quantize_pytorch import ResidualVQ
from local_attention import LocalMHA

from audiolm_pytorch.utils import curtail_to_multiple

@@ -254,7 +255,10 @@ class SoundStream(nn.Module):
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 1,
        target_sample_hz = 24000
        target_sample_hz = 24000,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8
    ):
        super().__init__()
        self.target_sample_hz = target_sample_hz # for resampling on the fly
@@ -277,6 +281,17 @@ class SoundStream(nn.Module):
            CausalConv1d(layer_channels[-1], codebook_dim, 3)
        )

        attn_kwargs = dict(
            dim = codebook_dim,
            dim_head = attn_dim_head,
            heads = attn_heads,
            window_size = attn_window_size,
            prenorm = True,
            causal = True
        )

        self.encoder_attn = LocalMHA(**attn_kwargs)

        self.rq = ResidualVQ(
            dim = codebook_dim,
            num_quantizers = rq_num_quantizers,
@@ -288,6 +303,8 @@ class SoundStream(nn.Module):
            quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        )

        self.decoder_attn = LocalMHA(**attn_kwargs)

        decoder_blocks = []

        for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)):
@@ -315,6 +332,8 @@ class SoundStream(nn.Module):
    def decode_from_codebook_indices(self, quantized_indices):
        codes = self.rq.get_codes_from_indices(quantized_indices)
        x = reduce(codes, 'q ... -> ...', 'sum')

        x = self.decoder_attn(x) + x
        x = rearrange(x, 'b n c -> b c n')
        return self.decoder(x)

@@ -354,7 +373,11 @@ class SoundStream(nn.Module):
        x = self.encoder(x)

        x = rearrange(x, 'b c n -> b n c')
        x = self.encoder_attn(x) + x

        x, indices, commit_loss = self.rq(x)

        x = self.decoder_attn(x) + x
        x = rearrange(x, 'b n c -> b c n')

        if return_encoded:
+1 −0
Original line number Diff line number Diff line
@@ -24,6 +24,7 @@ setup(
    'ema-pytorch',
    'fairseq',
    'joblib',
    'local-attention>=1.5.7',
    'scikit-learn',
    'sentencepiece',
    'torch>=1.6',