Loading audiolm_pytorch/audiolm_pytorch.py +64 −3 Original line number Diff line number Diff line Loading @@ -137,13 +137,16 @@ class STFTDiscriminator(nn.Module): def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') # reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024)) x = rearrange(x, 'b ... -> b 1 ...') Loading Loading @@ -243,7 +246,9 @@ class SoundStream(nn.Module): discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1., feature_loss_weight = 100 feature_loss_weight = 100, quantize_dropout = True, quantize_dropout_cutoff_index = 0 ): super().__init__() self.single_channel = input_channels == 1 Loading @@ -268,7 +273,9 @@ class SoundStream(nn.Module): num_quantizers = rq_num_quantizers, codebook_size = codebook_size, kmeans_init = True, threshold_ema_dead_code = 2 threshold_ema_dead_code = 2, quantize_dropout = quantize_dropout, quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) decoder_blocks = [] Loading Loading @@ -518,6 +525,60 @@ class Transformer(nn.Module): return self.norm(x) # the three hierarchical transformers class SemanticTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, dim, **kwargs ): super().__init__() def forward( self, semantic_token_ids ): raise NotImplemented class CoarseTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, num_coarse_tokens, dim, **kwargs ): super().__init__() def forward( self, semantic_token_ids, coarse_token_ids, ): raise NotImplemented class FineTransformer(nn.Module): def __init__( self, *, num_coarse_tokens, num_fine_tokens, dim, **kwargs ): super().__init__() def forward( self, coarse_token_ids, fine_token_ids ): raise NotImplemented # audio LM class AudioLM(nn.Module): Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ setup( 'einops>=0.4', 'ema-pytorch', 'torch>=1.6', 'vector-quantize-pytorch>=0.10.2' 'vector-quantize-pytorch>=0.10.5' ], classifiers=[ 'Development Status :: 4 - Beta', Loading Loading
audiolm_pytorch/audiolm_pytorch.py +64 −3 Original line number Diff line number Diff line Loading @@ -137,13 +137,16 @@ class STFTDiscriminator(nn.Module): def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') # reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024)) x = rearrange(x, 'b ... -> b 1 ...') Loading Loading @@ -243,7 +246,9 @@ class SoundStream(nn.Module): discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1., feature_loss_weight = 100 feature_loss_weight = 100, quantize_dropout = True, quantize_dropout_cutoff_index = 0 ): super().__init__() self.single_channel = input_channels == 1 Loading @@ -268,7 +273,9 @@ class SoundStream(nn.Module): num_quantizers = rq_num_quantizers, codebook_size = codebook_size, kmeans_init = True, threshold_ema_dead_code = 2 threshold_ema_dead_code = 2, quantize_dropout = quantize_dropout, quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) decoder_blocks = [] Loading Loading @@ -518,6 +525,60 @@ class Transformer(nn.Module): return self.norm(x) # the three hierarchical transformers class SemanticTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, dim, **kwargs ): super().__init__() def forward( self, semantic_token_ids ): raise NotImplemented class CoarseTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, num_coarse_tokens, dim, **kwargs ): super().__init__() def forward( self, semantic_token_ids, coarse_token_ids, ): raise NotImplemented class FineTransformer(nn.Module): def __init__( self, *, num_coarse_tokens, num_fine_tokens, dim, **kwargs ): super().__init__() def forward( self, coarse_token_ids, fine_token_ids ): raise NotImplemented # audio LM class AudioLM(nn.Module): Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ setup( 'einops>=0.4', 'ema-pytorch', 'torch>=1.6', 'vector-quantize-pytorch>=0.10.2' 'vector-quantize-pytorch>=0.10.5' ], classifiers=[ 'Development Status :: 4 - Beta', Loading