Commit 664a97da authored by Phil Wang's avatar Phil Wang
Browse files

one more modification so encodec wrapper can be used at naturalspeech2

parent 6e1e54bd
Loading
Loading
Loading
Loading
+13 −2
Original line number Diff line number Diff line
@@ -41,7 +41,7 @@ class EncodecWrapper(nn.Module):
    def seq_len_multiple_of(self):
        return reduce(lambda x, y: x * y, self.strides)

    def forward(self, x, **kwargs):
    def forward(self, x, return_encoded = False, **kwargs):
        # kwargs for stuff like return_encoded=True, which SoundStream uses but Encodec doesn't
        assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."
        # Unlike in the Encodec sample code in its README, x has already been resampled so we don't need to call
@@ -60,7 +60,14 @@ class EncodecWrapper(nn.Module):
        # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers]
        codes = rearrange(codes, 'b q n -> b n q')  # result: [batch, timesteps, num_quantizers]
        # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that.
        return None, codes, None

        # allow for returning of sum of quantized embeddings

        emb = None
        if return_encoded:
            emb = self.get_emb_from_indices(codes)

        return emb, codes, None

    def decode_from_codebook_indices(self, quantized_indices):
        # Input: batch x num tokens x num quantizers
@@ -81,6 +88,10 @@ class EncodecWrapper(nn.Module):
        #   back to b n anyways, but we'll keep this as a temporary hack just to make things work for now
        return rearrange(result, 'b n -> b 1 n')

    def get_emb_from_indices(self, indices):
        codes = rearrange(indices, 'b t q -> q b t')
        return self.model.quantizer.decode(codes)

    def decode(self, emb):
        return self.model.decoder(emb)

+1 −1
Original line number Diff line number Diff line
__version__ = '0.27.1'
__version__ = '0.27.2'