Loading README.md +101 −3 Original line number Diff line number Diff line Loading @@ -14,15 +14,110 @@ $ pip install audiolm-pytorch ## Usage First, `SoundStream` needs to be trained on a large corpus of audio data ```python from audiolm_pytorch import SoundStream, SoundStreamTrainer soundstream = SoundStream( codebook_size = 1024, rq_num_quantizers = 8, ) trainer = SoundStreamTrainer( soundstream, folder = '/path/to/librispeech', batch_size = 4, data_max_length = 320 * 32, num_train_steps = 10000 ).cuda() trainer.train() ``` Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained ex. `SemanticTransformer` ```python import torch from audiolm_pytorch import HubertWithKmeans, SemanticTransformer wav2vec = HubertWithKmeans( checkpoint_path = './hubert/hubert_base_ls960.pt', kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin' ) semantic_transformer = SemanticTransformer( wav2vec = wav2vec, dim = 1024, depth = 6 ).cuda() wave = torch.randn(1, 320 * 512).cuda() loss = semantic_transformer( raw_wave = wave, return_loss = True ) loss.backward() ``` ex. `CoarseTransformer` ```python import torch from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerWrapper wav2vec = HubertWithKmeans( checkpoint_path = './hubert/hubert_base_ls960.pt', kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin' ) soundstream = SoundStream( codebook_size = 1024, rq_num_quantizers = 8, ) coarse_transformer = CoarseTransformer( wav2vec = wav2vec, codebook_size = 1024, num_coarse_quantizers = 3, dim = 512, depth = 6 ) coarse_wrapper = CoarseTransformerWrapper( wav2vec = wav2vec, soundstream = soundstream, transformer = coarse_transformer ).cuda() wave = torch.randn(1, 32 * 320).cuda() loss = coarse_wrapper( raw_wave = wave, return_loss = True ) loss.backward() ``` ex. `FineTransformer` ```python import torch from audiolm_pytorch.audiolm_pytorch import SoundStream, AudioLM, FineTransformer, FineTransformerWrapper from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerWrapper soundstream = SoundStream( codebook_size = 1024, rq_num_quantizers = 8, ) soundstream.load('/path/to/trained/soundstream.pt') transformer = FineTransformer( num_coarse_quantizers = 3, num_fine_quantizers = 5, Loading @@ -36,16 +131,18 @@ train_wrapper = FineTransformerWrapper( transformer = transformer ).cuda() raw_waveform = torch.randn(1, 320 * 512).cuda() wave = torch.randn(1, 320 * 512).cuda() loss = train_wrapper( raw_wave = raw_waveform, raw_wave = wave, return_loss = True ) loss.backward() ``` - [ ] show how to generate from prompt tensor or file ## Appreciation - <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research Loading Loading @@ -79,6 +176,7 @@ loss.backward() - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] function for pretty printing all discriminator losses to log - [ ] simplify training even more within AudioLM class ## Citations Loading audiolm_pytorch/audiolm_pytorch.py +14 −2 Original line number Diff line number Diff line Loading @@ -267,8 +267,8 @@ class SemanticTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, dim, num_semantic_tokens = None, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -279,6 +279,12 @@ class SemanticTransformer(nn.Module): **kwargs ): super().__init__() assert exists(wav2vec) or exists(num_semantic_tokens) if exists(wav2vec): num_semantic_tokens = default(num_semantic_tokens, wav2vec.codebook_size) assert num_semantic_tokens == wav2vec.codebook_size self.has_condition = has_condition self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob Loading Loading @@ -362,10 +368,10 @@ class CoarseTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, codebook_size, num_coarse_quantizers, dim, num_semantic_tokens = None, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -374,6 +380,12 @@ class CoarseTransformer(nn.Module): **kwargs ): super().__init__() assert exists(wav2vec) or exists(num_semantic_tokens) if exists(wav2vec): num_semantic_tokens = default(num_semantic_tokens, wav2vec.codebook_size) assert num_semantic_tokens == wav2vec.codebook_size self.has_condition = has_condition self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob Loading audiolm_pytorch/soundstream.py +6 −0 Original line number Diff line number Diff line import functools from pathlib import Path from functools import partial import torch Loading Loading @@ -309,6 +310,11 @@ class SoundStream(nn.Module): self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def load(self, path): path = Path(path) assert path.exists() self.load_state_dict(torch.load(str(path))) def non_discr_parameters(self): return [*self.encoder.parameters(), *self.decoder.parameters()] Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.26', version = '0.0.27', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
README.md +101 −3 Original line number Diff line number Diff line Loading @@ -14,15 +14,110 @@ $ pip install audiolm-pytorch ## Usage First, `SoundStream` needs to be trained on a large corpus of audio data ```python from audiolm_pytorch import SoundStream, SoundStreamTrainer soundstream = SoundStream( codebook_size = 1024, rq_num_quantizers = 8, ) trainer = SoundStreamTrainer( soundstream, folder = '/path/to/librispeech', batch_size = 4, data_max_length = 320 * 32, num_train_steps = 10000 ).cuda() trainer.train() ``` Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained ex. `SemanticTransformer` ```python import torch from audiolm_pytorch import HubertWithKmeans, SemanticTransformer wav2vec = HubertWithKmeans( checkpoint_path = './hubert/hubert_base_ls960.pt', kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin' ) semantic_transformer = SemanticTransformer( wav2vec = wav2vec, dim = 1024, depth = 6 ).cuda() wave = torch.randn(1, 320 * 512).cuda() loss = semantic_transformer( raw_wave = wave, return_loss = True ) loss.backward() ``` ex. `CoarseTransformer` ```python import torch from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerWrapper wav2vec = HubertWithKmeans( checkpoint_path = './hubert/hubert_base_ls960.pt', kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin' ) soundstream = SoundStream( codebook_size = 1024, rq_num_quantizers = 8, ) coarse_transformer = CoarseTransformer( wav2vec = wav2vec, codebook_size = 1024, num_coarse_quantizers = 3, dim = 512, depth = 6 ) coarse_wrapper = CoarseTransformerWrapper( wav2vec = wav2vec, soundstream = soundstream, transformer = coarse_transformer ).cuda() wave = torch.randn(1, 32 * 320).cuda() loss = coarse_wrapper( raw_wave = wave, return_loss = True ) loss.backward() ``` ex. `FineTransformer` ```python import torch from audiolm_pytorch.audiolm_pytorch import SoundStream, AudioLM, FineTransformer, FineTransformerWrapper from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerWrapper soundstream = SoundStream( codebook_size = 1024, rq_num_quantizers = 8, ) soundstream.load('/path/to/trained/soundstream.pt') transformer = FineTransformer( num_coarse_quantizers = 3, num_fine_quantizers = 5, Loading @@ -36,16 +131,18 @@ train_wrapper = FineTransformerWrapper( transformer = transformer ).cuda() raw_waveform = torch.randn(1, 320 * 512).cuda() wave = torch.randn(1, 320 * 512).cuda() loss = train_wrapper( raw_wave = raw_waveform, raw_wave = wave, return_loss = True ) loss.backward() ``` - [ ] show how to generate from prompt tensor or file ## Appreciation - <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research Loading Loading @@ -79,6 +176,7 @@ loss.backward() - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] function for pretty printing all discriminator losses to log - [ ] simplify training even more within AudioLM class ## Citations Loading
audiolm_pytorch/audiolm_pytorch.py +14 −2 Original line number Diff line number Diff line Loading @@ -267,8 +267,8 @@ class SemanticTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, dim, num_semantic_tokens = None, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -279,6 +279,12 @@ class SemanticTransformer(nn.Module): **kwargs ): super().__init__() assert exists(wav2vec) or exists(num_semantic_tokens) if exists(wav2vec): num_semantic_tokens = default(num_semantic_tokens, wav2vec.codebook_size) assert num_semantic_tokens == wav2vec.codebook_size self.has_condition = has_condition self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob Loading Loading @@ -362,10 +368,10 @@ class CoarseTransformer(nn.Module): def __init__( self, *, num_semantic_tokens, codebook_size, num_coarse_quantizers, dim, num_semantic_tokens = None, t5_name = DEFAULT_T5_NAME, has_condition = False, cond_drop_prob = 0.5, Loading @@ -374,6 +380,12 @@ class CoarseTransformer(nn.Module): **kwargs ): super().__init__() assert exists(wav2vec) or exists(num_semantic_tokens) if exists(wav2vec): num_semantic_tokens = default(num_semantic_tokens, wav2vec.codebook_size) assert num_semantic_tokens == wav2vec.codebook_size self.has_condition = has_condition self.embed_text = partial(t5_encode_text, name = t5_name) self.cond_drop_prob = cond_drop_prob Loading
audiolm_pytorch/soundstream.py +6 −0 Original line number Diff line number Diff line import functools from pathlib import Path from functools import partial import torch Loading Loading @@ -309,6 +310,11 @@ class SoundStream(nn.Module): self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def load(self, path): path = Path(path) assert path.exists() self.load_state_dict(torch.load(str(path))) def non_discr_parameters(self): return [*self.encoder.parameters(), *self.decoder.parameters()] Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.26', version = '0.0.27', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading