Commit af0564d4 authored by Phil Wang
Browse files

add an adapter class for fairseq vq-wav2vec, make sure training of semantic...

add an adapter class for fairseq vq-wav2vec, make sure training of semantic and coarse transformers can happen end to end
parent 2ce09315
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -65,6 +65,7 @@ loss.backward()
- [ ] make sure full inference with or without prompting works on the `AudioLM` class
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
- [ ] DRY a little at the end
- [ ] figure out how to suppress logging in fairseq

## Citations

+2 −0
Original line number Diff line number Diff line
@@ -3,3 +3,5 @@ from audiolm_pytorch.audiolm_pytorch import SoundStream

from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
+28 −4
Original line number Diff line number Diff line
@@ -11,6 +11,8 @@ from einops import rearrange, repeat

from vector_quantize_pytorch import ResidualVQ

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec

# helper functions

def exists(val):
@@ -549,6 +551,7 @@ class SemanticTransformer(nn.Module):
        *,
        num_semantic_tokens,
        dim,
        wav2vec: Optional[FairseqVQWav2Vec] = None,
        **kwargs
    ):
        super().__init__()
@@ -556,14 +559,23 @@ class SemanticTransformer(nn.Module):

        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens)

    def forward(
        self,
        ids,
        *,
        raw_wave = None,
        ids = None,
        return_loss = False
    ):
        assert exists(raw_wave) ^ exists(ids)

        if not exists(ids):
            assert exists(self.wav2vec)
            ids = self.wav2vec(raw_wave)

        if return_loss:
            labels, ids = ids.clone(), ids[:, :-1]

@@ -594,6 +606,7 @@ class CoarseTransformer(nn.Module):
        codebook_size,
        num_coarse_quantizers,
        dim,
        wav2vec: Optional[FairseqVQWav2Vec] = None,
        **kwargs
    ):
        super().__init__()
@@ -602,6 +615,7 @@ class CoarseTransformer(nn.Module):
        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, **kwargs)

        self.codebook_size = codebook_size
@@ -612,6 +626,7 @@ class CoarseTransformer(nn.Module):

    def forward(
        self,
        *,
        semantic_token_ids,
        coarse_token_ids,
    ):
@@ -825,10 +840,13 @@ class CoarseTransformerWrapper(nn.Module):
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream]  = None,
        wav2vec: Optional[FairseqVQWav2Vec] = None,
        num_coarse_quantize = 3
    ):
        super().__init__()
        self.soundstream = soundstream
        self.wav2vec = wav2vec

        self.transformer = transformer

        assert num_coarse_quantize > 0
@@ -837,14 +855,20 @@ class CoarseTransformerWrapper(nn.Module):
    def forward(
        self,
        *,
        semantic_token_ids,
        semantic_token_ids = None,
        raw_wave = None,
        coarse_token_ids = None,
        return_loss = False
    ):
        assert exists(raw_wave) ^ exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
        assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'
        assert exists(raw_wave) or exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
        assert not all(map(exists, (raw_wave, semantic_token_ids, coarse_token_ids)))

        if exists(raw_wave):
        if not exists(semantic_token_ids):
            assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
            semantic_token_ids = self.wav2vec(raw_wave)

        if not exists(coarse_token_ids):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'

            with torch.no_grad():
+41 −0
Original line number Diff line number Diff line
from pathlib import Path

import torch
from torch import nn
from einops import rearrange

import fairseq

class FairseqVQWav2Vec(nn.Module):
    """Adapter that wraps a fairseq vq-wav2vec checkpoint as a frozen tokenizer.

    The first model returned by
    ``fairseq.checkpoint_utils.load_model_ensemble_and_task`` is stored on
    ``self.model`` and kept in eval mode. Calling the module runs the model's
    ``feature_extractor`` followed by ``vector_quantizer.forward_idx`` and
    returns the resulting codebook indices, without tracking gradients.

    Args:
        checkpoint_path: location of the checkpoint file on disk (``str`` or
            ``pathlib.Path``).

    Raises:
        AssertionError: if ``checkpoint_path`` does not exist.
    """

    def __init__(
        self,
        checkpoint_path
    ):
        super().__init__()
        path = Path(checkpoint_path)
        assert path.exists(), f'path {checkpoint_path} does not exist'

        # normalise to str so pathlib.Path inputs produce a plain string key below
        checkpoint_path = str(checkpoint_path)

        # map storages to cpu so a checkpoint saved from a gpu also loads on cpu-only machines
        checkpoint = torch.load(checkpoint_path, map_location = 'cpu')

        # NOTE(review): fairseq presumably only iterates over the keys (filenames) of this
        # mapping - confirm whether pre-loading the checkpoint here is actually needed
        load_model_input = {checkpoint_path: checkpoint}
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

        self.model = model[0]
        self.model.eval()

    def train(self, mode = True):
        """Set the mode of this module while keeping the wrapped model in eval mode.

        ``nn.Module.train`` recurses into children, so without this override a
        parent module calling ``.train()`` would flip the wrapped model back
        into training mode even though it is only used under ``no_grad``.

        Args:
            mode: forwarded to ``nn.Module.train`` for this wrapper itself.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.model.eval()
        return self

    @property
    def groups(self):
        """Value of ``groups`` on the wrapped model's vector quantizer."""
        return self.model.vector_quantizer.groups

    @property
    def codebook_size(self):
        """Size of the first dimension of the vector quantizer's ``embedding``."""
        return self.model.vector_quantizer.embedding.shape[0]

    @torch.no_grad()
    def forward(self, wav_input, flatten = True):
        """Convert raw audio into codebook indices.

        Args:
            wav_input: tensor passed straight to the model's feature extractor
                (assumes a batch of raw waveforms - TODO confirm expected shape).
            flatten: when ``True`` (default) every dimension after the first is
                merged into one; when ``False`` the indices are returned exactly
                as produced by ``vector_quantizer.forward_idx``.

        Returns:
            Tensor of codebook indices.
        """
        embed = self.model.feature_extractor(wav_input)

        # forward_idx returns a pair; only the second element (the indices) is used
        _, codebook_indices = self.model.vector_quantizer.forward_idx(embed)

        if not flatten:
            return codebook_indices

        return rearrange(codebook_indices, 'b ... -> b (...)')
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.4',
  version = '0.0.5',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',