Unverified Commit edf792b8 authored by Phil Wang's avatar Phil Wang Committed by GitHub
Browse files

Merge pull request #135 from LWprogramming/encodec_support

Encodec support
parents 26762007 310deea3
Loading
Loading
Loading
Loading
+12 −5
Original line number Diff line number Diff line
@@ -42,9 +42,16 @@ $ pip install audiolm-pytorch

## Usage

### SoundStream
### SoundStream & Encodec

First, `SoundStream` needs to be trained on a large corpus of audio data
There are two options for the neural codec. If you want to use the pretrained 24kHz Encodec, just create an `EncodecWrapper` object as follows:
```python
from audiolm_pytorch import EncodecWrapper
encodec = EncodecWrapper()
# Now you can use the encodec variable in the same way you'd use the soundstream variables below.
```

Otherwise, to stay closer to the original paper, you can use `SoundStream`. First, `SoundStream` needs to be trained on a large corpus of audio data:

```python
from audiolm_pytorch import SoundStream, SoundStreamTrainer
@@ -152,7 +159,7 @@ coarse_transformer = CoarseTransformer(

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = '/path/to/audio/files',
    batch_size = 1,
@@ -181,7 +188,7 @@ fine_transformer = FineTransformer(

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    codec = soundstream,
    folder = '/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
@@ -198,7 +205,7 @@ from audiolm_pytorch import AudioLM

audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
+1 −1
Original line number Diff line number Diff line
from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.soundstream import SoundStream, AudioLMSoundStream, MusicLMSoundStream
from audiolm_pytorch.soundstream import SoundStream, AudioLMSoundStream, MusicLMSoundStream, EncodecWrapper

from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper, SemanticTransformerWrapper
+26 −26
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

from torchaudio.functional import resample

from audiolm_pytorch.soundstream import SoundStream
from audiolm_pytorch.soundstream import SoundStream, EncodecWrapper
from audiolm_pytorch.utils import AudioConditionerBase

from tqdm import tqdm
@@ -1268,7 +1268,7 @@ class CoarseTransformerWrapper(nn.Module):
        self,
        *,
        transformer: CoarseTransformer,
        soundstream: Optional[SoundStream]  = None,
        codec: Optional[Union[SoundStream, EncodecWrapper]]  = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        pad_id = -1,
@@ -1277,7 +1277,7 @@ class CoarseTransformerWrapper(nn.Module):
        mask_prob = 0.15
    ):
        super().__init__()
        self.soundstream = soundstream
        self.codec = codec
        self.wav2vec = wav2vec

        self.transformer = transformer
@@ -1369,9 +1369,9 @@ class CoarseTransformerWrapper(nn.Module):
        if not reconstruct_wave:
            return sampled_coarse_token_ids

        assert exists(self.soundstream)
        assert exists(self.codec)

        wav = self.soundstream.decode_from_codebook_indices(sampled_coarse_token_ids)
        wav = self.codec.decode_from_codebook_indices(sampled_coarse_token_ids)
        return rearrange(wav, 'b 1 n -> b n')

    def forward(
@@ -1379,7 +1379,7 @@ class CoarseTransformerWrapper(nn.Module):
        *,
        semantic_token_ids = None,
        raw_wave = None,
        raw_wave_for_soundstream = None,
        raw_wave_for_codec = None,
        text = None,
        text_embeds = None,
        coarse_token_ids = None,
@@ -1388,10 +1388,10 @@ class CoarseTransformerWrapper(nn.Module):
    ):
        assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'

        raw_wave_for_soundstream = default(raw_wave_for_soundstream, raw_wave)
        assert exists(raw_wave_for_soundstream) or exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
        raw_wave_for_codec = default(raw_wave_for_codec, raw_wave)
        assert exists(raw_wave_for_codec) or exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

        assert not all(map(exists, (raw_wave, raw_wave_for_soundstream, semantic_token_ids, coarse_token_ids)))
        assert not all(map(exists, (raw_wave, raw_wave_for_codec, semantic_token_ids, coarse_token_ids)))

        if exists(self.audio_conditioner):
            assert exists(raw_wave)
@@ -1403,11 +1403,11 @@ class CoarseTransformerWrapper(nn.Module):
            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)

        if not exists(coarse_token_ids):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
            assert exists(self.codec), 'SoundStream must be provided if given raw wave for training'

            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave_for_soundstream, return_encoded = True)
                self.codec.eval()
                _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True)
                coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:]

        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')
@@ -1484,14 +1484,14 @@ class FineTransformerWrapper(nn.Module):
        self,
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        codec: Optional[Union[SoundStream, EncodecWrapper]] = None,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        coarse_cross_entropy_loss_weight = 1.,
        pad_id = -1,
        mask_prob = 0.15
    ):
        super().__init__()
        self.soundstream = soundstream
        self.codec = codec

        self.transformer = transformer
        self.audio_conditioner = audio_conditioner
@@ -1501,8 +1501,8 @@ class FineTransformerWrapper(nn.Module):
        self.num_fine_quantizers = transformer.num_fine_quantizers
        self.num_coarse_quantizers = transformer.num_coarse_quantizers

        if exists(soundstream):
            assert (self.num_fine_quantizers + self.num_coarse_quantizers) == soundstream.num_quantizers, 'number of fine and coarse quantizers on fine transformer must add up to total number of quantizers on soundstream'
        if exists(codec):
            assert (self.num_fine_quantizers + self.num_coarse_quantizers) == codec.num_quantizers, 'number of fine and coarse quantizers on fine transformer must add up to total number of quantizers on codec'

        self.eos_id = transformer.eos_id

@@ -1596,13 +1596,13 @@ class FineTransformerWrapper(nn.Module):
        if not reconstruct_wave:
            return sampled_fine_token_ids

        # reconstruct the wave using soundstream, concatting the fine and coarse token ids together first across quantization dimension
        # reconstruct the wave using codec, concatting the fine and coarse token ids together first across quantization dimension

        assert exists(self.soundstream)
        assert exists(self.codec)

        coarse_and_fine_ids = torch.cat((coarse_token_ids, sampled_fine_token_ids), dim = -1)

        wav = self.soundstream.decode_from_codebook_indices(coarse_and_fine_ids)
        wav = self.codec.decode_from_codebook_indices(coarse_and_fine_ids)
        return rearrange(wav, 'b 1 n -> b n')

    def forward(
@@ -1625,11 +1625,11 @@ class FineTransformerWrapper(nn.Module):
            text_embeds = self.audio_conditioner(wavs = raw_wave, namespace = 'fine') # technically audio embeds, but shared text-audio joint embedding space for mulan

        if exists(raw_wave):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
            assert exists(self.codec), 'SoundStream must be provided if given raw wave for training'

            with torch.no_grad():
                self.soundstream.eval()
                _, token_ids, _ = self.soundstream(raw_wave, return_encoded = True)
                self.codec.eval()
                _, token_ids, _ = self.codec(raw_wave, return_encoded = True)

        if exists(token_ids):
            coarse_token_ids, fine_token_ids = token_ids[..., :self.num_coarse_quantizers], token_ids[..., self.num_coarse_quantizers:]
@@ -1706,7 +1706,7 @@ class AudioLM(nn.Module):
        self,
        *,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], 
        soundstream: SoundStream,
        codec: Union[SoundStream, EncodecWrapper],
        semantic_transformer: SemanticTransformer,
        coarse_transformer: CoarseTransformer,
        fine_transformer: FineTransformer,
@@ -1720,7 +1720,7 @@ class AudioLM(nn.Module):
        assert semantic_transformer.num_semantic_tokens == coarse_transformer.num_semantic_tokens
        assert coarse_transformer.codebook_size == fine_transformer.codebook_size
        assert coarse_transformer.num_coarse_quantizers == fine_transformer.num_coarse_quantizers
        assert (fine_transformer.num_coarse_quantizers + fine_transformer.num_fine_quantizers) == soundstream.num_quantizers
        assert (fine_transformer.num_coarse_quantizers + fine_transformer.num_fine_quantizers) == codec.num_quantizers

        self.semantic_has_condition = semantic_transformer.has_condition
        self.coarse_has_condition = coarse_transformer.has_condition
@@ -1736,14 +1736,14 @@ class AudioLM(nn.Module):

        self.coarse = CoarseTransformerWrapper(
            wav2vec = wav2vec,
            soundstream = soundstream,
            codec= codec,
            transformer = coarse_transformer,
            audio_conditioner = audio_conditioner,
            unique_consecutive = unique_consecutive
        )

        self.fine = FineTransformerWrapper(
            soundstream = soundstream,
            codec= codec,
            transformer = fine_transformer,
            audio_conditioner = audio_conditioner
        )
+86 −0
Original line number Diff line number Diff line
@@ -29,6 +29,9 @@ parsed_version = version.parse(__version__)

import pickle

from encodec import EncodecModel
from encodec.utils import convert_audio, _linear_overlap_add

# helper functions

def exists(val):
@@ -398,6 +401,89 @@ class LocalTransformer(nn.Module):
            x = ff(x) + x

        return x
class EncodecWrapper(nn.Module):
    """
    Support pretrained 24kHz Encodec by Meta AI, if you want to skip training SoundStream.

    The wrapper replicates the pieces of the SoundStream interface that are used externally:
    the `target_sample_hz`, `num_quantizers` and `seq_len_multiple_of` attributes, a `forward`
    that returns a `(None, codes, None)` triple, and `decode_from_codebook_indices`.

    Token ids are exchanged as `batch x timesteps x num_quantizers` in both directions, i.e. what
    `forward` returns can be passed straight to `decode_from_codebook_indices`.

    TODO:
    - see if we need to keep the scaled version and somehow persist the scale factors for when we need to decode? Right
        now I'm just setting self.model.normalize = False to sidestep all of that
    - see if we can use the 48kHz model, which is specifically for music. Right now we're using the 24kHz model because
        that's what was used in MusicLM and avoids any resampling issues.
    """
    def __init__(self,
                 target_sample_hz=24000,
                 strides=(2,4,5,8),
                 num_quantizers=8,
                 ):
        """
        Args:
            target_sample_hz: sample rate the codec operates at; only 24000 is accepted.
            strides: per-stage strides; their product is exposed as `seq_len_multiple_of`.
            num_quantizers: number of quantizers; only 8 is accepted (tied to the 6.0 target bandwidth below).
        """
        super().__init__()
        # Instantiate a pretrained EnCodec model
        self.model = EncodecModel.encodec_model_24khz()
        self.model.normalize = False # this means we don't need to scale codes e.g. when running model.encode(wav)

        # bandwidth affects num quantizers used: https://github.com/facebookresearch/encodec/pull/41
        self.model.set_target_bandwidth(6.0)
        assert num_quantizers == 8, "assuming 8 quantizers for now, see bandwidth comment above"

        # Fields that SoundStream has that get used externally. We replicate them here.
        self.target_sample_hz = target_sample_hz
        assert self.target_sample_hz == 24000, "haven't done anything with non-24kHz yet"
        self.num_quantizers = num_quantizers
        self.strides = strides # used in seq_len_multiple_of

    @property
    def seq_len_multiple_of(self):
        """Product of all strides."""
        return functools.reduce(lambda x, y: x * y, self.strides)

    def forward(self, x, x_sampling_rate=24000, **kwargs):
        """
        Encode raw audio into discrete codes.

        Args:
            x: waveform tensor of shape batch x num samples.
            x_sampling_rate: sample rate of `x`, forwarded to `convert_audio`.
            **kwargs: accepted and ignored, for stuff like return_encoded=True, which SoundStream uses but
                Encodec doesn't.

        Returns:
            `(None, codes, None)` where `codes` is batch x timesteps x num_quantizers. The two `None`
            placeholders keep the 3-tuple shape callers unpack (`_, indices, _ = codec(...)`).
        """
        assert not self.model.training, "Encodec is pretrained and should never be called outside eval mode."
        # Give every batch element its own explicit (mono) channel axis *before* convert_audio.
        # NOTE(review): convert_audio (Encodec 0.1.1) appears to treat the second-to-last axis as the channel axis,
        # so handing it a bare `b n` tensor would mix batch elements together for batch sizes above 1. For batch
        # size 1 the result is the same `1 x 1 x n` tensor as the previous convert_audio + unsqueeze(0) sequence.
        wav = rearrange(x, 'b n -> b 1 n')
        # convert_audio up-samples if necessary, e.g. if wav has n samples at 16 kHz and model is 48 kHz,
        # then resulting wav has 3n samples because you do n * 48/16
        # Note: this is a bit of a hack but we avoid any resampling issues here if we just try 24kHz throughout
        # which makes convert_audio a no-op
        wav = convert_audio(wav, x_sampling_rate, self.model.sample_rate, self.model.channels)
        # Extract discrete codes from EnCodec
        with torch.no_grad():
            encoded_frames = self.model.encode(wav)
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [batch, num_quantizers, timesteps]
        # Everything downstream (the transformer wrappers and decode_from_codebook_indices below) indexes the
        # quantizers on the LAST axis, so move them there: [batch, timesteps, num_quantizers]
        codes = rearrange(codes, 'b q t -> b t q')
        # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that.
        return None, codes, None

    def decode_from_codebook_indices(self, quantized_indices):
        """
        Decode token ids back into a waveform.

        Input: batch x num tokens x num quantizers
        Output: batch x 1 x num samples
        """

        assert self.model.sample_rate == 24000,\
            "if changing to 48kHz, that model segments its audio into lengths of 1.0 second with 1% overlap, whereas " \
            "the 24kHz doesn't segment at all. this means the frame decode logic might change; this is a reminder to " \
            "double check that."
        # Since 24kHz pretrained doesn't do any segmenting, we have all the frames already (1 frame = 1 token in quantized_indices)

        # The following code is hacked in from self.model.decode() (Encodec version 0.1.1) where we skip the part about
        # scaling.
        # Shape: 1 x (num_frames * stride product). 1 because we have 1 frame (because no segmenting)
        frames = self._decode_frame(quantized_indices)
        # NOTE(review): `_linear_overlap_add` is handed a single tensor here rather than a list of frames, so it
        # presumably iterates over the leading axis — looks safe only for batch size 1; confirm before decoding
        # larger batches.
        result = _linear_overlap_add(frames, self.model.segment_stride or 1)
        # TODO: I'm not overly pleased with this because when this function gets called, we just rearrange the result
        #   back to b n anyways, but we'll keep this as a temporary hack just to make things work for now
        return rearrange(result, 'b n -> b 1 n')

    def _decode_frame(self, quantized_indices):
        """
        Run the quantizer decode followed by the Encodec decoder on a block of token ids.

        Input: batch x num tokens x num quantizers
        Output: whatever `self.model.decoder` returns for the de-quantized embeddings.
        """
        # The following code is hacked in from self.model._decode_frame() (Encodec version 0.1.1) where we assume we've
        # already unwrapped the EncodedFrame
        # Output: batch x new_num_samples, where new_num_samples is num_frames * stride product (may be slightly
        # larger than original num samples as a result, because the last frame might not be "fully filled" with samples
        # if num_samples doesn't divide perfectly).
        # num_frames == the number of acoustic tokens you have, one token per frame
        codes = rearrange(quantized_indices, 'b t q -> q b t')  # quantizer axis first, as the quantizer expects
        emb = self.model.quantizer.decode(codes)
        # emb shape: batch x self.model.quantizer.dimension x T. Note self.model.quantizer.dimension is the embedding dimension
        return self.model.decoder(emb)

class SoundStream(nn.Module):
    def __init__(
+14 −14
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ from audiolm_pytorch.optimizer import get_optimizer

from ema_pytorch import EMA

from audiolm_pytorch.soundstream import SoundStream
from audiolm_pytorch.soundstream import SoundStream, EncodecWrapper

from audiolm_pytorch.audiolm_pytorch import (
    SemanticTransformer,
@@ -735,14 +735,14 @@ class CoarseTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer: CoarseTransformer,
        soundstream: SoundStream,
        codec: Union[SoundStream, EncodecWrapper],
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        *,
        num_train_steps,
        batch_size,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),
        data_max_length = None,
        data_max_length_seconds = None,
        folder = None,
@@ -762,12 +762,12 @@ class CoarseTransformerTrainer(nn.Module):
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.transformer = transformer
        self.soundstream = soundstream
        self.codec = codec
        self.wav2vec = wav2vec
        self.audio_conditioner = audio_conditioner

        self.train_wrapper = CoarseTransformerWrapper(
            soundstream = soundstream,
            codec = codec,
            wav2vec = wav2vec,
            transformer = transformer,
            audio_conditioner = audio_conditioner
@@ -797,16 +797,16 @@ class CoarseTransformerTrainer(nn.Module):
            assert not (exists(data_max_length) and exists(data_max_length_seconds))

            if exists(data_max_length_seconds):
                data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, soundstream.target_sample_hz))
                data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, codec.target_sample_hz))

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
                target_sample_hz = (
                    wav2vec.target_sample_hz,
                    soundstream.target_sample_hz
                    codec.target_sample_hz
                ), # need 2 waves resampled differently here
                seq_len_multiple_of = soundstream.seq_len_multiple_of
                seq_len_multiple_of = codec.seq_len_multiple_of
            )

        self.ds_fields = ds_fields
@@ -985,7 +985,7 @@ class FineTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer: FineTransformer,
        soundstream: SoundStream,
        codec: Union[SoundStream, EncodecWrapper],
        *,
        num_train_steps,
        batch_size,
@@ -1011,11 +1011,11 @@ class FineTransformerTrainer(nn.Module):
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.transformer = transformer
        self.soundstream = soundstream
        self.codec = codec
        self.audio_conditioner = audio_conditioner

        self.train_wrapper = FineTransformerWrapper(
            soundstream = soundstream,
            codec = codec,
            transformer = transformer,
            audio_conditioner = audio_conditioner
        )
@@ -1044,13 +1044,13 @@ class FineTransformerTrainer(nn.Module):
            assert not (exists(data_max_length) and exists(data_max_length_seconds))

            if exists(data_max_length_seconds):
                data_max_length = data_max_length_seconds * soundstream.target_sample_hz
                data_max_length = data_max_length_seconds * codec.target_sample_hz

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
                target_sample_hz = soundstream.target_sample_hz,
                seq_len_multiple_of = soundstream.seq_len_multiple_of
                target_sample_hz = codec.target_sample_hz,
                seq_len_multiple_of = codec.seq_len_multiple_of
            )

        self.ds_fields = None
Loading