Commit 90904ee5 authored by Phil Wang's avatar Phil Wang
Browse files

update to latest vq library, make sure residual quantizer can dropout during...

update to latest vq library, make sure residual quantizer can dropout during training of soundstream
parent 3ae3a0dd
Loading
Loading
Loading
Loading
+64 −3
Original line number Diff line number Diff line
@@ -137,13 +137,16 @@ class STFTDiscriminator(nn.Module):

    def forward(self, x, return_intermediates = False):
        x = rearrange(x, 'b 1 n -> b n')
        # reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows:

        '''
        reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows:

        The STFT-based discriminator is illustrated in Figure 4
        and operates on a single scale, computing the STFT with a
        window length of W = 1024 samples and a hop length of
        H = 256 samples
        '''

        x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024))
        x = rearrange(x, 'b ... -> b 1 ...')

@@ -243,7 +246,9 @@ class SoundStream(nn.Module):
        discr_multi_scales = (1, 0.5, 0.25),
        recon_loss_weight = 1.,
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100
        feature_loss_weight = 100,
        quantize_dropout = True,
        quantize_dropout_cutoff_index = 0
    ):
        super().__init__()
        self.single_channel = input_channels == 1
@@ -268,7 +273,9 @@ class SoundStream(nn.Module):
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            kmeans_init = True,
            threshold_ema_dead_code = 2
            threshold_ema_dead_code = 2,
            quantize_dropout = quantize_dropout,
            quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        )

        decoder_blocks = []
@@ -518,6 +525,60 @@ class Transformer(nn.Module):

        return self.norm(x)

# the three hierarchical transformers

class SemanticTransformer(nn.Module):
    """Placeholder module for the semantic stage of the hierarchical transformers.

    The class is currently a stub: the constructor only initialises the
    ``nn.Module`` base and builds no layers, and ``forward`` is not
    implemented yet.
    """

    def __init__(
        self,
        *,
        num_semantic_tokens,
        dim,
        **kwargs
    ):
        """Initialise the stub.

        Args:
            num_semantic_tokens: accepted but not used yet.
            dim: accepted but not used yet.
            **kwargs: any further keyword arguments are accepted and ignored.
        """
        super().__init__()

    def forward(
        self,
        semantic_token_ids
    ):
        """Not implemented yet.

        Args:
            semantic_token_ids: currently unused.

        Raises:
            NotImplementedError: always.
        """
        # `NotImplemented` is the sentinel returned by binary-operator dunder
        # methods, not an exception; raising it produces a confusing TypeError.
        # `NotImplementedError` is the exception meant for unimplemented methods.
        raise NotImplementedError

class CoarseTransformer(nn.Module):
    """Placeholder module for the coarse stage of the hierarchical transformers.

    The class is currently a stub: the constructor only initialises the
    ``nn.Module`` base and builds no layers, and ``forward`` is not
    implemented yet.
    """

    def __init__(
        self,
        *,
        num_semantic_tokens,
        num_coarse_tokens,
        dim,
        **kwargs
    ):
        """Initialise the stub.

        Args:
            num_semantic_tokens: accepted but not used yet.
            num_coarse_tokens: accepted but not used yet.
            dim: accepted but not used yet.
            **kwargs: any further keyword arguments are accepted and ignored.
        """
        super().__init__()

    def forward(
        self,
        semantic_token_ids,
        coarse_token_ids,
    ):
        """Not implemented yet.

        Args:
            semantic_token_ids: currently unused.
            coarse_token_ids: currently unused.

        Raises:
            NotImplementedError: always.
        """
        # `NotImplemented` is the sentinel returned by binary-operator dunder
        # methods, not an exception; raising it produces a confusing TypeError.
        # `NotImplementedError` is the exception meant for unimplemented methods.
        raise NotImplementedError

class FineTransformer(nn.Module):
    """Placeholder module for the fine stage of the hierarchical transformers.

    The class is currently a stub: the constructor only initialises the
    ``nn.Module`` base and builds no layers, and ``forward`` is not
    implemented yet.
    """

    def __init__(
        self,
        *,
        num_coarse_tokens,
        num_fine_tokens,
        dim,
        **kwargs
    ):
        """Initialise the stub.

        Args:
            num_coarse_tokens: accepted but not used yet.
            num_fine_tokens: accepted but not used yet.
            dim: accepted but not used yet.
            **kwargs: any further keyword arguments are accepted and ignored.
        """
        super().__init__()

    def forward(
        self,
        coarse_token_ids,
        fine_token_ids
    ):
        """Not implemented yet.

        Args:
            coarse_token_ids: currently unused.
            fine_token_ids: currently unused.

        Raises:
            NotImplementedError: always.
        """
        # `NotImplemented` is the sentinel returned by binary-operator dunder
        # methods, not an exception; raising it produces a confusing TypeError.
        # `NotImplementedError` is the exception meant for unimplemented methods.
        raise NotImplementedError

# audio LM

class AudioLM(nn.Module):
+1 −1
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ setup(
    'einops>=0.4',
    'ema-pytorch',
    'torch>=1.6',
    'vector-quantize-pytorch>=0.10.2'
    'vector-quantize-pytorch>=0.10.5'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',