Commit f54e19d5 authored by Phil Wang's avatar Phil Wang
Browse files

introduce attn_dynamic_pos_bias for soundstream, which should have the best...

introduce attn_dynamic_pos_bias for soundstream, which should have the best length extrapolation properties for attention
parent d4680bed
Loading
Loading
Loading
Loading
+32 −9
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ from einops import rearrange, reduce, pack, unpack
from vector_quantize_pytorch import ResidualVQ

from local_attention import LocalMHA
from local_attention.transformer import FeedForward
from local_attention.transformer import FeedForward, DynamicPositionBias

from audiolm_pytorch.utils import curtail_to_multiple

@@ -314,20 +314,40 @@ def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)):
        ResidualUnit(chan_out, chan_out, next(it)),
    )

class LocalTransformerBlock(nn.Module):
class LocalTransformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        heads,
        window_size,
        dynamic_pos_bias = False,
        **kwargs
    ):
        super().__init__()
        self.attn = LocalMHA(dim = dim, qk_rmsnorm = True, use_xpos = True, **kwargs)
        self.ff = FeedForward(dim = dim)
        self.window_size = window_size
        self.layers = nn.ModuleList([])

        self.pos_bias = None
        if dynamic_pos_bias:
            self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LocalMHA(dim = dim, heads = heads, qk_rmsnorm = True, window_size = window_size, use_rotary_pos_emb = not dynamic_pos_bias, use_xpos = True, **kwargs),
                FeedForward(dim = dim)
            ]))

    def forward(self, x):
        x = self.attn(x) + x
        x = self.ff(x) + x
        w = self.window_size

        attn_bias = self.pos_bias(w, w * 2) if exists(self.pos_bias) else None

        for attn, ff in self.layers:
            x = attn(x, attn_bias = attn_bias) + x
            x = ff(x) + x

        return x

class SoundStream(nn.Module):
@@ -361,7 +381,8 @@ class SoundStream(nn.Module):
        attn_dim_head = 64,
        attn_heads = 8,
        attn_depth = 1,
        attn_xpos_scale_base = None
        attn_xpos_scale_base = None,
        attn_dynamic_pos_bias = False
    ):
        super().__init__()

@@ -398,13 +419,15 @@ class SoundStream(nn.Module):
            dim = codebook_dim,
            dim_head = attn_dim_head,
            heads = attn_heads,
            depth = attn_depth,
            window_size = attn_window_size,
            xpos_scale_base = attn_xpos_scale_base,
            dynamic_pos_bias = attn_dynamic_pos_bias,
            prenorm = True,
            causal = True
        )

        self.encoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
        self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None

        self.num_quantizers = rq_num_quantizers

@@ -420,7 +443,7 @@ class SoundStream(nn.Module):
            quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        )

        self.decoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
        self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None

        decoder_blocks = []

+1 −1
Original line number Diff line number Diff line
__version__ = '0.20.0'
__version__ = '0.21.0'
+1 −1
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ setup(
    'fairseq',
    'joblib',
    'lion-pytorch',
    'local-attention>=1.7.2',
    'local-attention>=1.8.1',
    'scikit-learn',
    'sentencepiece',
    'torch>=1.12',