Commit a11722e6 authored by Phil Wang's avatar Phil Wang
Browse files

listen to @eonglints and add hubert with kmeans as an option

parent ed313d3a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -5,3 +5,4 @@ from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransform
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
+8 −7
Original line number Diff line number Diff line
import math
from functools import partial
from typing import Optional
from typing import Optional, Union

import torch
from torch import nn, einsum
@@ -12,6 +12,7 @@ from einops import rearrange, repeat
from vector_quantize_pytorch import ResidualVQ

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

# helper functions

@@ -551,7 +552,7 @@ class SemanticTransformer(nn.Module):
        *,
        num_semantic_tokens,
        dim,
        wav2vec: Optional[FairseqVQWav2Vec] = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        **kwargs
    ):
        super().__init__()
@@ -574,7 +575,7 @@ class SemanticTransformer(nn.Module):

        if not exists(ids):
            assert exists(self.wav2vec)
            ids = self.wav2vec(raw_wave)
            ids = self.wav2vec(raw_wave, flatten = False)
            
        if return_loss:
            labels, ids = ids.clone(), ids[:, :-1]
@@ -606,7 +607,7 @@ class CoarseTransformer(nn.Module):
        codebook_size,
        num_coarse_quantizers,
        dim,
        wav2vec: Optional[FairseqVQWav2Vec] = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        **kwargs
    ):
        super().__init__()
@@ -840,7 +841,7 @@ class CoarseTransformerWrapper(nn.Module):
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream]  = None,
        wav2vec: Optional[FairseqVQWav2Vec] = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        num_coarse_quantize = 3
    ):
        super().__init__()
@@ -866,7 +867,7 @@ class CoarseTransformerWrapper(nn.Module):

        if not exists(semantic_token_ids):
            assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
            semantic_token_ids = self.wav2vec(raw_wave)
            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)

        if not exists(coarse_token_ids):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'
+56 −0
Original line number Diff line number Diff line
from pathlib import Path

import torch
from torch import nn
from einops import rearrange, pack, unpack

import joblib
import fairseq

class HubertWithKmeans(nn.Module):
    def __init__(
        self,
        checkpoint_path,
        kmeans_path
    ):
        super().__init__()
        model_path = Path(checkpoint_path)
        kmeans_path = Path(kmeans_path)

        assert model_path.exists(), f'path {checkpoint_path} does not exist'
        assert kmeans_path.exists(), f'path {kmeans_path} does not exist'

        checkpoint = torch.load(checkpoint_path)
        load_model_input = {checkpoint_path: checkpoint}
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

        self.model = model[0]
        self.model.eval()

        kmeans = joblib.load(kmeans_path)
        self.kmeans = kmeans

    @property
    def groups(self):
        return 1

    @property
    def codebook_size(self):
        return self.kmeans.n_clusters

    @torch.no_grad()
    def forward(self, wav_input, flatten = True):
        device = wav_input.device

        embed = self.model(wav_input, features_only = True)
        embed, packed_shape = pack([embed['x']], '* d')

        codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())

        codebook_indices = torch.from_numpy(codebook_indices).to(device).long()

        if flatten:
            return codebook_indices

        codebook_indices, = unpack(codebook_indices, packed_shape, '*')
        return codebook_indices
+3 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.5',
  version = '0.0.6',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -19,9 +19,10 @@ setup(
  ],
  install_requires=[
    'accelerate',
    'einops>=0.5',
    'einops>=0.6',
    'ema-pytorch',
    'fairseq',
    'joblib',
    'torch>=1.6',
    'vector-quantize-pytorch>=0.10.5'
  ],