Commit ec642cd9 authored by Phil Wang's avatar Phil Wang
Browse files

make sure local attention blocks are followed by feedforward

parent 5a7fee75
Loading
Loading
Loading
Loading
+22 −4
Original line number Diff line number Diff line
@@ -12,7 +12,9 @@ import torchaudio.transforms as T
from einops import rearrange, reduce

from vector_quantize_pytorch import ResidualVQ

from local_attention import LocalMHA
from local_attention.transformer import FeedForward

from audiolm_pytorch.utils import curtail_to_multiple

@@ -248,6 +250,22 @@ def DecoderBlock(chan_in, chan_out, stride):
        ResidualUnit(chan_out, chan_out, 9),
    )

class LocalTransformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        **kwargs
    ):
        super().__init__()
        self.attn = LocalMHA(dim = dim, **kwargs)
        self.ff = FeedForward(dim = dim)

    def forward(self, x):
        x = self.attn(x) + x
        x = self.ff(x) + x
        return x

class SoundStream(nn.Module):
    def __init__(
        self,
@@ -302,7 +320,7 @@ class SoundStream(nn.Module):
            causal = True
        )

        self.encoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None
        self.encoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None

        self.rq = ResidualVQ(
            dim = codebook_dim,
@@ -316,7 +334,7 @@ class SoundStream(nn.Module):
            quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        )

        self.decoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None
        self.decoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None

        decoder_blocks = []

@@ -393,12 +411,12 @@ class SoundStream(nn.Module):
        x = rearrange(x, 'b c n -> b n c')

        if exists(self.encoder_attn):
            x = self.encoder_attn(x) + x
            x = self.encoder_attn(x)

        x, indices, commit_loss = self.rq(x)

        if exists(self.decoder_attn):
            x = self.decoder_attn(x) + x
            x = self.decoder_attn(x)

        x = rearrange(x, 'b n c -> b c n')

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.8',
  version = '0.5.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',