Commit e28ff435 authored by Phil Wang's avatar Phil Wang
Browse files

instantiate own ResidualVQ for cross entropy loss at natural speech 2

parent b973f1a2
Loading
Loading
Loading
Loading
+32 −6
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ from einops import rearrange, pack, unpack
import torch
from torch import nn

from vector_quantize_pytorch import ResidualVQ

from encodec import EncodecModel
from encodec.utils import _linear_overlap_add

@@ -19,7 +21,8 @@ class EncodecWrapper(nn.Module):
    -

    """
    def __init__(self,
    def __init__(
        self,
        target_sample_hz = 24000,
        strides = (2, 4, 5, 8),
        num_quantizers = 8,
@@ -41,11 +44,34 @@ class EncodecWrapper(nn.Module):
        self.num_quantizers = num_quantizers
        self.strides = strides # used in seq_len_multiple_of

        # cross entropy loss to indices passed in on l2 distance logits introduced in vector-quantize-pytorch 1.2.2

        self.rq = ResidualVQ(
            dim = 128,
            codebook_size = 1024,
            num_quantizers = 8
        )

        # copy codebook over to ResidualVQ for cross entropy loss logic from naturalspeech2
        # luckily, it seems Meta AI basically used my ResidualVQ code verbatim. makes porting it over easy

        for encodec_rq_layer, rq_layer in zip(self.model.quantizer.vq.layers, self.rq.layers):
            encodec_codebook = dict(encodec_rq_layer._codebook.named_buffers()).get('embed')
            vq_codebook = dict(rq_layer._codebook.named_buffers()).get('embed')

            encodec_codebook = rearrange(encodec_codebook, '... -> 1 ...')
            vq_codebook.copy_(encodec_codebook)

    @property
    def seq_len_multiple_of(self):
        return reduce(lambda x, y: x * y, self.strides)

    def forward(self, x, return_encoded = False, **kwargs):
    def forward(
        self,
        x,
        return_encoded = False,
        **kwargs
    ):

        x, ps = pack([x], '* n')

+1 −1
Original line number Diff line number Diff line
__version__ = '0.27.3'
__version__ = '0.27.4'
+1 −1
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ setup(
    'torchaudio',
    'transformers',
    'tqdm',
    'vector-quantize-pytorch>=1.0.6'
    'vector-quantize-pytorch>=1.2.2'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',