Loading README.md +10 −0 Original line number Diff line number Diff line Loading @@ -107,6 +107,16 @@ loss = coarse_wrapper( ) loss.backward() # after a lot of training mock_semantic_token_ids = torch.randint(0, wav2vec.codebook_size, (1, 128)) coarse_tokens = coarse_wrapper.generate( semantic_token_ids = mock_semantic_token_ids, max_time_steps = 512 ) # (1, 512, 3) - (batch, time steps, num quantizers) ``` ex. `FineTransformer` Loading audiolm_pytorch/audiolm_pytorch.py +65 −3 Original line number Diff line number Diff line import math from functools import partial from typing import Optional, Union from typeguard import typechecked import torch from torch import nn, einsum Loading Loading @@ -305,6 +307,7 @@ class Transformer(nn.Module): # the three hierarchical transformers @typechecked class SemanticTransformer(nn.Module): def __init__( self, Loading Loading @@ -518,6 +521,7 @@ class SemanticTransformer(nn.Module): return loss @typechecked class CoarseTransformer(nn.Module): def __init__( self, Loading Loading @@ -589,7 +593,8 @@ class CoarseTransformer(nn.Module): self_attn_mask = None, text = None, text_embeds = None, cond_drop_prob = None cond_drop_prob = None, return_only_coarse_logits = False ): b, device = semantic_token_ids.shape[0], semantic_token_ids.device Loading Loading @@ -638,7 +643,7 @@ class CoarseTransformer(nn.Module): # semantic logits semantic_logits = self.to_semantic_logits(pred_semantic_tokens) semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits else None # get coarse logits Loading Loading @@ -821,7 +826,7 @@ class FineTransformer(nn.Module): # training wrappers @typechecked class CoarseTransformerWrapper(nn.Module): def __init__( self, Loading @@ -841,6 +846,61 @@ class CoarseTransformerWrapper(nn.Module): self.pad_id = pad_id self.num_coarse_quantizers = transformer.num_coarse_quantizers self.eos_id = transformer.coarse_eos_id @property def device(self): return next(self.parameters()).device @eval_decorator @torch.no_grad() def generate( self, *, semantic_token_ids, max_time_steps = 512, cond_scale = 3., filter_thres = 0.9, temperature = 1., reshape_output = True ): batch, device = semantic_token_ids.shape[0], self.device semantic_token_ids = semantic_token_ids.to(device) coarse_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long) init_coarse_time_step = coarse_token_ids.shape[-1] output = coarse_token_ids.clone() for time_step in range(init_coarse_time_step, max_time_steps): for ind in range(self.num_coarse_quantizers): is_last_step = ind == (self.num_coarse_quantizers - 1) _, coarse_logits = self.transformer.forward_with_cond_scale( coarse_token_ids = coarse_token_ids, semantic_token_ids = semantic_token_ids, cond_scale = cond_scale, return_only_coarse_logits = True ) last_coarse_logits = coarse_logits[:, -1] if not is_last_step: last_coarse_logits[:, -1] = float('-inf') # prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval filtered_logits = top_k(last_coarse_logits, thres = filter_thres) sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') output = torch.cat((output, sampled), dim = -1) output = mask_out_after_eos_id(output, self.eos_id, include_eos = False) if reshape_output: output = rearrange(output, 'b (n q) -> b n q', q = self.num_coarse_quantizers) return output def forward( self, Loading Loading @@ -918,6 +978,7 @@ class CoarseTransformerWrapper(nn.Module): return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits) @typechecked class FineTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -999,6 +1060,7 @@ class FineTransformerWrapper(nn.Module): # audio LM @typechecked class AudioLM(nn.Module): def __init__( self, Loading setup.py +2 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.39', version = '0.0.40', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -26,6 +26,7 @@ setup( 'torch>=1.6', 'torchaudio', 'transformers', 'typeguard', 'vector-quantize-pytorch>=0.10.10' ], classifiers=[ Loading Loading
README.md +10 −0 Original line number Diff line number Diff line Loading @@ -107,6 +107,16 @@ loss = coarse_wrapper( ) loss.backward() # after a lot of training mock_semantic_token_ids = torch.randint(0, wav2vec.codebook_size, (1, 128)) coarse_tokens = coarse_wrapper.generate( semantic_token_ids = mock_semantic_token_ids, max_time_steps = 512 ) # (1, 512, 3) - (batch, time steps, num quantizers) ``` ex. `FineTransformer` Loading
audiolm_pytorch/audiolm_pytorch.py +65 −3 Original line number Diff line number Diff line import math from functools import partial from typing import Optional, Union from typeguard import typechecked import torch from torch import nn, einsum Loading Loading @@ -305,6 +307,7 @@ class Transformer(nn.Module): # the three hierarchical transformers @typechecked class SemanticTransformer(nn.Module): def __init__( self, Loading Loading @@ -518,6 +521,7 @@ class SemanticTransformer(nn.Module): return loss @typechecked class CoarseTransformer(nn.Module): def __init__( self, Loading Loading @@ -589,7 +593,8 @@ class CoarseTransformer(nn.Module): self_attn_mask = None, text = None, text_embeds = None, cond_drop_prob = None cond_drop_prob = None, return_only_coarse_logits = False ): b, device = semantic_token_ids.shape[0], semantic_token_ids.device Loading Loading @@ -638,7 +643,7 @@ class CoarseTransformer(nn.Module): # semantic logits semantic_logits = self.to_semantic_logits(pred_semantic_tokens) semantic_logits = self.to_semantic_logits(pred_semantic_tokens) if not return_only_coarse_logits else None # get coarse logits Loading Loading @@ -821,7 +826,7 @@ class FineTransformer(nn.Module): # training wrappers @typechecked class CoarseTransformerWrapper(nn.Module): def __init__( self, Loading @@ -841,6 +846,61 @@ class CoarseTransformerWrapper(nn.Module): self.pad_id = pad_id self.num_coarse_quantizers = transformer.num_coarse_quantizers self.eos_id = transformer.coarse_eos_id @property def device(self): return next(self.parameters()).device @eval_decorator @torch.no_grad() def generate( self, *, semantic_token_ids, max_time_steps = 512, cond_scale = 3., filter_thres = 0.9, temperature = 1., reshape_output = True ): batch, device = semantic_token_ids.shape[0], self.device semantic_token_ids = semantic_token_ids.to(device) coarse_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long) init_coarse_time_step = coarse_token_ids.shape[-1] output = coarse_token_ids.clone() for time_step in range(init_coarse_time_step, max_time_steps): for ind in range(self.num_coarse_quantizers): is_last_step = ind == (self.num_coarse_quantizers - 1) _, coarse_logits = self.transformer.forward_with_cond_scale( coarse_token_ids = coarse_token_ids, semantic_token_ids = semantic_token_ids, cond_scale = cond_scale, return_only_coarse_logits = True ) last_coarse_logits = coarse_logits[:, -1] if not is_last_step: last_coarse_logits[:, -1] = float('-inf') # prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval filtered_logits = top_k(last_coarse_logits, thres = filter_thres) sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') output = torch.cat((output, sampled), dim = -1) output = mask_out_after_eos_id(output, self.eos_id, include_eos = False) if reshape_output: output = rearrange(output, 'b (n q) -> b n q', q = self.num_coarse_quantizers) return output def forward( self, Loading Loading @@ -918,6 +978,7 @@ class CoarseTransformerWrapper(nn.Module): return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits) @typechecked class FineTransformerWrapper(nn.Module): def __init__( self, Loading Loading @@ -999,6 +1060,7 @@ class FineTransformerWrapper(nn.Module): # audio LM @typechecked class AudioLM(nn.Module): def __init__( self, Loading
setup.py +2 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.39', version = '0.0.40', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -26,6 +26,7 @@ setup( 'torch>=1.6', 'torchaudio', 'transformers', 'typeguard', 'vector-quantize-pytorch>=0.10.10' ], classifiers=[ Loading