Commit b973f1a2 authored by Phil Wang
Browse files

expose codebook dimensions on encodec and soundstream, so naturalspeech can...

expose codebook dimensions on encodec and soundstream, so naturalspeech can auto-set hyperparameters
parent 664a97da
Loading
Loading
Loading
Loading
+14 −2
Original line number Diff line number Diff line
from functools import reduce
from einops import rearrange
from einops import rearrange, pack, unpack

import torch
from torch import nn

from encodec import EncodecModel
from encodec.utils import _linear_overlap_add

@@ -34,6 +36,8 @@ class EncodecWrapper(nn.Module):
        # Fields that SoundStream has that get used externally. We replicate them here.
        self.target_sample_hz = target_sample_hz
        assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet"

        self.codebook_dim = 128
        self.num_quantizers = num_quantizers
        self.strides = strides # used in seq_len_multiple_of

@@ -42,6 +46,9 @@ class EncodecWrapper(nn.Module):
        return reduce(lambda x, y: x * y, self.strides)

    def forward(self, x, return_encoded = False, **kwargs):

        x, ps = pack([x], '* n')

        # kwargs for stuff like return_encoded=True, which SoundStream uses but Encodec doesn't
        assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."
        # Unlike in the Encodec sample code in its README, x has already been resampled so we don't need to call
@@ -67,6 +74,9 @@ class EncodecWrapper(nn.Module):
        if return_encoded:
            emb = self.get_emb_from_indices(codes)

        emb, = unpack(emb, ps, '* n c')
        codes, = unpack(codes, ps, '* n q')

        return emb, codes, None

    def decode_from_codebook_indices(self, quantized_indices):
@@ -90,9 +100,11 @@ class EncodecWrapper(nn.Module):

    def get_emb_from_indices(self, indices):
        """Look up the quantizer embeddings for a batch of codebook indices.

        Args:
            indices: codebook indices laid out as (b, t, q), per the
                rearrange pattern below.

        Returns:
            The embeddings produced by the model's quantizer, moved to a
            (b, n, c) layout. This is the layout that `decode` and the
            `unpack(emb, ps, '* n c')` call in `forward` consume.
        """
        # The quantizer's decode is handed the indices quantizer-axis first.
        codes = rearrange(indices, 'b t q -> q b t')
        emb = self.model.quantizer.decode(codes)

        # An earlier `return self.model.quantizer.decode(codes)` sat here and
        # made this transpose unreachable, so callers got (b, c, n) instead
        # of the (b, n, c) layout they expect.
        return rearrange(emb, 'b c n -> b n c')

    def decode(self, emb):
        """Run the wrapped model's decoder on embeddings given as (b, n, c)."""
        # The decoder is fed channels-first input, so swap the last two axes.
        channels_first = rearrange(emb, 'b n c -> b c n')
        return self.model.decoder(channels_first)

    def _decode_frame(self, quantized_indices):
+2 −0
Original line number Diff line number Diff line
@@ -484,6 +484,8 @@ class SoundStream(nn.Module):

        self.num_quantizers = rq_num_quantizers

        self.codebook_dim = codebook_dim

        self.rq = ResidualVQ(
            dim = codebook_dim,
            num_quantizers = rq_num_quantizers,
+1 −1
Original line number Diff line number Diff line
__version__ = '0.27.2'
__version__ = '0.27.3'