Loading audiolm_pytorch/soundstream.py +22 −4 Original line number Diff line number Diff line Loading @@ -12,7 +12,9 @@ import torchaudio.transforms as T from einops import rearrange, reduce from vector_quantize_pytorch import ResidualVQ from local_attention import LocalMHA from local_attention.transformer import FeedForward from audiolm_pytorch.utils import curtail_to_multiple Loading Loading @@ -248,6 +250,22 @@ def DecoderBlock(chan_in, chan_out, stride): ResidualUnit(chan_out, chan_out, 9), ) class LocalTransformerBlock(nn.Module): def __init__( self, *, dim, **kwargs ): super().__init__() self.attn = LocalMHA(dim = dim, **kwargs) self.ff = FeedForward(dim = dim) def forward(self, x): x = self.attn(x) + x x = self.ff(x) + x return x class SoundStream(nn.Module): def __init__( self, Loading Loading @@ -302,7 +320,7 @@ class SoundStream(nn.Module): causal = True ) self.encoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None self.encoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None self.rq = ResidualVQ( dim = codebook_dim, Loading @@ -316,7 +334,7 @@ class SoundStream(nn.Module): quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) self.decoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None self.decoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None decoder_blocks = [] Loading Loading @@ -393,12 +411,12 @@ class SoundStream(nn.Module): x = rearrange(x, 'b c n -> b n c') if exists(self.encoder_attn): x = self.encoder_attn(x) + x x = self.encoder_attn(x) x, indices, commit_loss = self.rq(x) if exists(self.decoder_attn): x = self.decoder_attn(x) + x x = self.decoder_attn(x) x = rearrange(x, 'b n c -> b c n') Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.4.8', version = '0.5.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/soundstream.py +22 −4 Original line number Diff line number Diff line Loading @@ -12,7 +12,9 @@ import torchaudio.transforms as T from einops import rearrange, reduce from vector_quantize_pytorch import ResidualVQ from local_attention import LocalMHA from local_attention.transformer import FeedForward from audiolm_pytorch.utils import curtail_to_multiple Loading Loading @@ -248,6 +250,22 @@ def DecoderBlock(chan_in, chan_out, stride): ResidualUnit(chan_out, chan_out, 9), ) class LocalTransformerBlock(nn.Module): def __init__( self, *, dim, **kwargs ): super().__init__() self.attn = LocalMHA(dim = dim, **kwargs) self.ff = FeedForward(dim = dim) def forward(self, x): x = self.attn(x) + x x = self.ff(x) + x return x class SoundStream(nn.Module): def __init__( self, Loading Loading @@ -302,7 +320,7 @@ class SoundStream(nn.Module): causal = True ) self.encoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None self.encoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None self.rq = ResidualVQ( dim = codebook_dim, Loading @@ -316,7 +334,7 @@ class SoundStream(nn.Module): quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) self.decoder_attn = LocalMHA(**attn_kwargs) if use_local_attn else None self.decoder_attn = LocalTransformerBlock(**attn_kwargs) if use_local_attn else None decoder_blocks = [] Loading Loading @@ -393,12 +411,12 @@ class SoundStream(nn.Module): x = rearrange(x, 'b c n -> b n c') if exists(self.encoder_attn): x = self.encoder_attn(x) + x x = self.encoder_attn(x) x, indices, commit_loss = self.rq(x) if exists(self.decoder_attn): x = self.decoder_attn(x) + x x = self.decoder_attn(x) x = rearrange(x, 'b n c -> b c n') Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.4.8', version = '0.5.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading