Loading audiolm_pytorch/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -5,3 +5,4 @@ from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransform from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans audiolm_pytorch/audiolm_pytorch.py +8 −7 Original line number Diff line number Diff line import math from functools import partial from typing import Optional from typing import Optional, Union import torch from torch import nn, einsum Loading @@ -12,6 +12,7 @@ from einops import rearrange, repeat from vector_quantize_pytorch import ResidualVQ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans # helper functions Loading Loading @@ -551,7 +552,7 @@ class SemanticTransformer(nn.Module): *, num_semantic_tokens, dim, wav2vec: Optional[FairseqVQWav2Vec] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, **kwargs ): super().__init__() Loading @@ -574,7 +575,7 @@ class SemanticTransformer(nn.Module): if not exists(ids): assert exists(self.wav2vec) ids = self.wav2vec(raw_wave) ids = self.wav2vec(raw_wave, flatten = False) if return_loss: labels, ids = ids.clone(), ids[:, :-1] Loading Loading @@ -606,7 +607,7 @@ class CoarseTransformer(nn.Module): codebook_size, num_coarse_quantizers, dim, wav2vec: Optional[FairseqVQWav2Vec] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, **kwargs ): super().__init__() Loading Loading @@ -840,7 +841,7 @@ class CoarseTransformerWrapper(nn.Module): *, transformer: FineTransformer, soundstream: Optional[SoundStream] = None, wav2vec: Optional[FairseqVQWav2Vec] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, num_coarse_quantize = 3 ): super().__init__() Loading @@ -866,7 +867,7 @@ class 
CoarseTransformerWrapper(nn.Module): if not exists(semantic_token_ids): assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training' semantic_token_ids = self.wav2vec(raw_wave) semantic_token_ids = self.wav2vec(raw_wave, flatten = False) if not exists(coarse_token_ids): assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training' Loading audiolm_pytorch/hubert_kmeans.py 0 → 100644 +56 −0 Original line number Diff line number Diff line from pathlib import Path import torch from torch import nn from einops import rearrange, pack, unpack import joblib import fairseq class HubertWithKmeans(nn.Module): def __init__( self, checkpoint_path, kmeans_path ): super().__init__() model_path = Path(checkpoint_path) kmeans_path = Path(kmeans_path) assert model_path.exists(), f'path {checkpoint_path} does not exist' assert kmeans_path.exists(), f'path {kmeans_path} does not exist' checkpoint = torch.load(checkpoint_path) load_model_input = {checkpoint_path: checkpoint} model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) self.model = model[0] self.model.eval() kmeans = joblib.load(kmeans_path) self.kmeans = kmeans @property def groups(self): return 1 @property def codebook_size(self): return self.kmeans.n_clusters @torch.no_grad() def forward(self, wav_input, flatten = True): device = wav_input.device embed = self.model(wav_input, features_only = True) embed, packed_shape = pack([embed['x']], '* d') codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy()) codebook_indices = torch.from_numpy(codebook_indices).to(device).long() if flatten: return codebook_indices codebook_indices, = unpack(codebook_indices, packed_shape, '*') return codebook_indices setup.py +3 −2 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.5', version = 
'0.0.6', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -19,9 +19,10 @@ setup( ], install_requires=[ 'accelerate', 'einops>=0.5', 'einops>=0.6', 'ema-pytorch', 'fairseq', 'joblib', 'torch>=1.6', 'vector-quantize-pytorch>=0.10.5' ], Loading Loading
audiolm_pytorch/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -5,3 +5,4 @@ from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransform from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
audiolm_pytorch/audiolm_pytorch.py +8 −7 Original line number Diff line number Diff line import math from functools import partial from typing import Optional from typing import Optional, Union import torch from torch import nn, einsum Loading @@ -12,6 +12,7 @@ from einops import rearrange, repeat from vector_quantize_pytorch import ResidualVQ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans # helper functions Loading Loading @@ -551,7 +552,7 @@ class SemanticTransformer(nn.Module): *, num_semantic_tokens, dim, wav2vec: Optional[FairseqVQWav2Vec] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, **kwargs ): super().__init__() Loading @@ -574,7 +575,7 @@ class SemanticTransformer(nn.Module): if not exists(ids): assert exists(self.wav2vec) ids = self.wav2vec(raw_wave) ids = self.wav2vec(raw_wave, flatten = False) if return_loss: labels, ids = ids.clone(), ids[:, :-1] Loading Loading @@ -606,7 +607,7 @@ class CoarseTransformer(nn.Module): codebook_size, num_coarse_quantizers, dim, wav2vec: Optional[FairseqVQWav2Vec] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, **kwargs ): super().__init__() Loading Loading @@ -840,7 +841,7 @@ class CoarseTransformerWrapper(nn.Module): *, transformer: FineTransformer, soundstream: Optional[SoundStream] = None, wav2vec: Optional[FairseqVQWav2Vec] = None, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None, num_coarse_quantize = 3 ): super().__init__() Loading @@ -866,7 +867,7 @@ class CoarseTransformerWrapper(nn.Module): if not exists(semantic_token_ids): assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training' semantic_token_ids = self.wav2vec(raw_wave) semantic_token_ids = self.wav2vec(raw_wave, flatten = False) if not exists(coarse_token_ids): assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training' Loading
# audiolm_pytorch/hubert_kmeans.py (new file, +56 -0)

from pathlib import Path

import torch
from torch import nn
from einops import rearrange, pack, unpack

import joblib
import fairseq

class HubertWithKmeans(nn.Module):
    """
    Turns raw waveforms into discrete cluster ids.

    A model restored through fairseq produces per-position feature vectors,
    and a k-means object restored through joblib assigns every vector to its
    cluster. The resulting ids are what the transformers in this package
    consume as semantic tokens when called with `flatten = False`.

    Args:
        checkpoint_path: path of the fairseq checkpoint (presumably a HuBERT
            checkpoint, judging by the class name - confirm with the caller).
        kmeans_path: path of the joblib-serialized k-means object; it must
            expose `predict` and `n_clusters`.

    NOTE: `torch.load` and `joblib.load` both unpickle their input, so only
    point this class at files from a trusted source.
    """

    def __init__(
        self,
        checkpoint_path,
        kmeans_path
    ):
        super().__init__()

        model_path = Path(checkpoint_path)
        kmeans_path = Path(kmeans_path)

        # fail early, before any expensive loading is attempted
        assert model_path.exists(), f'path {checkpoint_path} does not exist'
        assert kmeans_path.exists(), f'path {kmeans_path} does not exist'

        # NOTE(review): fairseq appears to iterate only over the keys (paths)
        # of this mapping, which would make the eagerly loaded checkpoint
        # value unused - confirm against the installed fairseq version
        load_model_input = {checkpoint_path: torch.load(checkpoint_path)}
        models, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

        # only the first member of the returned ensemble is used, in eval mode
        feature_extractor = models[0]
        feature_extractor.eval()
        self.model = feature_extractor

        self.kmeans = joblib.load(kmeans_path)

    @property
    def groups(self):
        """Number of codebook groups - always 1 for this wrapper."""
        return 1

    @property
    def codebook_size(self):
        """Number of distinct ids, taken from the k-means cluster count."""
        return self.kmeans.n_clusters

    @torch.no_grad()
    def forward(self, wav_input, flatten = True):
        """
        Map `wav_input` to k-means cluster ids.

        Args:
            wav_input: tensor handed straight to the wrapped model
                (assumed to be a batch of raw waveforms - TODO confirm).
            flatten: when True, return the ids as one flat vector covering
                every feature position; when False, restore the leading
                dimensions the features had before clustering.

        Returns:
            Long tensor of cluster ids on the same device as `wav_input`.
        """
        target_device = wav_input.device

        features = self.model(wav_input, features_only = True)['x']

        # collapse all leading dims into one so k-means sees a 2d matrix
        flat_features, leading_shape = pack([features], '* d')

        # k-means runs on numpy, so the features take a round trip via cpu
        cluster_ids = self.kmeans.predict(flat_features.cpu().detach().numpy())
        cluster_ids = torch.from_numpy(cluster_ids).to(target_device).long()

        if not flatten:
            cluster_ids, = unpack(cluster_ids, leading_shape, '*')

        return cluster_ids
setup.py +3 −2 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.5', version = '0.0.6', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -19,9 +19,10 @@ setup( ], install_requires=[ 'accelerate', 'einops>=0.5', 'einops>=0.6', 'ema-pytorch', 'fairseq', 'joblib', 'torch>=1.6', 'vector-quantize-pytorch>=0.10.5' ], Loading