Commit 5268c045 authored by Leon Wu
Browse files

Remove unnecessary convert_audio

parent 9f2d1532
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from einops import rearrange
import torch
from torch import nn
from encodec import EncodecModel
from encodec.utils import convert_audio, _linear_overlap_add
from encodec.utils import _linear_overlap_add

class EncodecWrapper(nn.Module):
    """
@@ -41,15 +41,15 @@ class EncodecWrapper(nn.Module):
    def seq_len_multiple_of(self):
        """Return the product of every value in ``self.strides``.

        NOTE(review): judging by the name, callers presumably need input
        sequence lengths to be a multiple of this value -- confirm at call sites.
        """
        def _times(running_product, stride):
            # Pairwise multiplication folded across the strides.
            return running_product * stride

        return reduce(_times, self.strides)

    def forward(self, x, x_sampling_rate=24000, **kwargs):
    def forward(self, x, **kwargs):
        # kwargs for stuff like return_encoded=True, which SoundStream uses but Encodec doesn't
        assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."
        # convert_audio up-samples if necessary, e.g. if wav has n samples at 16 kHz and model is 48 kHz,
        # then resulting wav has 3n samples because you do n * 48/16
        # Note: this is a bit of a hack but we avoid any resampling issues here if we just try 24kHz throughout
        # which makes convert_audio a no-op
        wav = convert_audio(x, x_sampling_rate, self.model.sample_rate, self.model.channels)
        wav = wav.unsqueeze(0)
        # Unlike in the Encodec sample code in its README, x has already been resampled so we don't need to call
        # convert_audio and unsqueeze. The convert_audio function also doesn't play nicely with batches.

        # b = batch, t = timesteps, 1 channel for the 24kHz model, 2 channels for the 48kHz model
        wav = rearrange(x, f'b t -> b {self.model.channels} t')

        # Extract discrete codes from EnCodec
        with torch.no_grad():
            encoded_frames = self.model.encode(wav)