audiolm_pytorch/soundstream.py +32 −9
@@ -19,7 +19,7 @@
 from einops import rearrange, reduce, pack, unpack
 from vector_quantize_pytorch import ResidualVQ
 from local_attention import LocalMHA
-from local_attention.transformer import FeedForward
+from local_attention.transformer import FeedForward, DynamicPositionBias
 
 from audiolm_pytorch.utils import curtail_to_multiple
 
@@ -314,20 +314,40 @@ def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)):
         ResidualUnit(chan_out, chan_out, next(it)),
     )
 
-class LocalTransformerBlock(nn.Module):
+class LocalTransformer(nn.Module):
     def __init__(
         self,
         *,
         dim,
+        depth,
+        heads,
+        window_size,
+        dynamic_pos_bias = False,
         **kwargs
     ):
         super().__init__()
-        self.attn = LocalMHA(dim = dim, qk_rmsnorm = True, use_xpos = True, **kwargs)
-        self.ff = FeedForward(dim = dim)
+        self.window_size = window_size
+        self.layers = nn.ModuleList([])
+
+        self.pos_bias = None
+        if dynamic_pos_bias:
+            self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)
+
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                LocalMHA(dim = dim, heads = heads, qk_rmsnorm = True, window_size = window_size, use_rotary_pos_emb = not dynamic_pos_bias, use_xpos = True, **kwargs),
+                FeedForward(dim = dim)
+            ]))
 
     def forward(self, x):
-        x = self.attn(x) + x
-        x = self.ff(x) + x
+        w = self.window_size
+
+        attn_bias = self.pos_bias(w, w * 2) if exists(self.pos_bias) else None
+
+        for attn, ff in self.layers:
+            x = attn(x, attn_bias = attn_bias) + x
+            x = ff(x) + x
+
         return x
 
 class SoundStream(nn.Module):
@@ -361,7 +381,8 @@ class SoundStream(nn.Module):
         attn_dim_head = 64,
         attn_heads = 8,
         attn_depth = 1,
-        attn_xpos_scale_base = None
+        attn_xpos_scale_base = None,
+        attn_dynamic_pos_bias = False
     ):
         super().__init__()
 
@@ -398,13 +419,15 @@ class SoundStream(nn.Module):
             dim = codebook_dim,
             dim_head = attn_dim_head,
             heads = attn_heads,
+            depth = attn_depth,
             window_size = attn_window_size,
             xpos_scale_base = attn_xpos_scale_base,
+            dynamic_pos_bias = attn_dynamic_pos_bias,
             prenorm = True,
             causal = True
         )
 
-        self.encoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
+        self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
 
         self.num_quantizers = rq_num_quantizers
 
@@ -420,7 +443,7 @@ class SoundStream(nn.Module):
             quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         )
 
-        self.decoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
+        self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
 
         decoder_blocks = []
 

audiolm_pytorch/version.py +1 −1
-__version__ = '0.20.0'
+__version__ = '0.21.0'

setup.py +1 −1
@@ -26,7 +26,7 @@ setup(
     'fairseq',
     'joblib',
     'lion-pytorch',
-    'local-attention>=1.7.2',
+    'local-attention>=1.8.1',
     'scikit-learn',
     'sentencepiece',
     'torch>=1.12',
audiolm_pytorch/soundstream.py +32 −9
@@ -19,7 +19,7 @@
 from einops import rearrange, reduce, pack, unpack
 from vector_quantize_pytorch import ResidualVQ
 from local_attention import LocalMHA
-from local_attention.transformer import FeedForward
+from local_attention.transformer import FeedForward, DynamicPositionBias
 
 from audiolm_pytorch.utils import curtail_to_multiple
 
@@ -314,20 +314,40 @@ def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)):
         ResidualUnit(chan_out, chan_out, next(it)),
     )
 
-class LocalTransformerBlock(nn.Module):
+class LocalTransformer(nn.Module):
     def __init__(
         self,
         *,
         dim,
+        depth,
+        heads,
+        window_size,
+        dynamic_pos_bias = False,
         **kwargs
     ):
         super().__init__()
-        self.attn = LocalMHA(dim = dim, qk_rmsnorm = True, use_xpos = True, **kwargs)
-        self.ff = FeedForward(dim = dim)
+        self.window_size = window_size
+        self.layers = nn.ModuleList([])
+
+        self.pos_bias = None
+        if dynamic_pos_bias:
+            self.pos_bias = DynamicPositionBias(dim = dim // 2, heads = heads)
+
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                LocalMHA(dim = dim, heads = heads, qk_rmsnorm = True, window_size = window_size, use_rotary_pos_emb = not dynamic_pos_bias, use_xpos = True, **kwargs),
+                FeedForward(dim = dim)
+            ]))
 
     def forward(self, x):
-        x = self.attn(x) + x
-        x = self.ff(x) + x
+        w = self.window_size
+
+        attn_bias = self.pos_bias(w, w * 2) if exists(self.pos_bias) else None
+
+        for attn, ff in self.layers:
+            x = attn(x, attn_bias = attn_bias) + x
+            x = ff(x) + x
+
         return x
 
 class SoundStream(nn.Module):
@@ -361,7 +381,8 @@ class SoundStream(nn.Module):
         attn_dim_head = 64,
         attn_heads = 8,
         attn_depth = 1,
-        attn_xpos_scale_base = None
+        attn_xpos_scale_base = None,
+        attn_dynamic_pos_bias = False
     ):
         super().__init__()
 
@@ -398,13 +419,15 @@ class SoundStream(nn.Module):
             dim = codebook_dim,
             dim_head = attn_dim_head,
             heads = attn_heads,
+            depth = attn_depth,
             window_size = attn_window_size,
             xpos_scale_base = attn_xpos_scale_base,
+            dynamic_pos_bias = attn_dynamic_pos_bias,
             prenorm = True,
             causal = True
         )
 
-        self.encoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
+        self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
 
         self.num_quantizers = rq_num_quantizers
 
@@ -420,7 +443,7 @@ class SoundStream(nn.Module):
             quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         )
 
-        self.decoder_attn = nn.Sequential(*[LocalTransformerBlock(**attn_kwargs) for _ in range(attn_depth)]) if use_local_attn else None
+        self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
 
         decoder_blocks = []
 
audiolm_pytorch/version.py +1 −1
-__version__ = '0.20.0'
+__version__ = '0.21.0'
setup.py +1 −1
@@ -26,7 +26,7 @@ setup(
     'fairseq',
     'joblib',
     'lion-pytorch',
-    'local-attention>=1.7.2',
+    'local-attention>=1.8.1',
     'scikit-learn',
     'sentencepiece',
     'torch>=1.12',