Loading README.md +24 −1 Original line number Diff line number Diff line Loading @@ -155,7 +155,29 @@ loss = train_wrapper( loss.backward() ``` - [ ] show how to generate from prompt tensor or file All together now ```python audiolm = AudioLM( wav2vec = wav2vec, soundstream = soundstream, semantic_transformer = semantic_transformer, coarse_transformer = coarse_transformer, fine_transformer = transformer ) generated_wav = audiolm(batch_size = 1) # or with priming generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8)) # or with text condition, if given generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells']) ``` ## Appreciation Loading Loading @@ -192,6 +214,7 @@ loss.backward() - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] simplify training even more within AudioLM class - [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory ## Citations Loading audiolm_pytorch/audiolm_pytorch.py +206 −41 Original line number Diff line number Diff line import math from functools import partial from typing import Optional, Union from typing import Optional, Union, List from typeguard import typechecked import torch Loading @@ -21,6 +21,8 @@ from torchaudio.functional import resample from audiolm_pytorch.soundstream import SoundStream from tqdm import tqdm # helper functions def exists(val): Loading Loading @@ -356,7 +358,7 @@ class SemanticTransformer(nn.Module): self, *, max_length, text = None, text: Optional[List[str]] = None, text_embeds = None, prime_wave = None, prime_ids = None, Loading Loading @@ -396,17 +398,17 @@ class SemanticTransformer(nn.Module): batch = ids.shape[0] start_length = ids.shape[-1] output = ids.clone() sample_semantic_ids = ids.clone() batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1') last_logit_indices = (ids != self.pad_id).sum(dim = -1).long() # sample from transformer for ind in range(start_length, max_length): for ind in tqdm(range(start_length, max_length), desc = 'generating semantic'): logits = self.forward_with_cond_scale( ids = output, ids = sample_semantic_ids, text_embeds = text_embeds, **kwargs ) Loading @@ -419,18 +421,18 @@ class SemanticTransformer(nn.Module): sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') output = torch.cat((output, sampled), dim = -1) sample_semantic_ids = torch.cat((sample_semantic_ids, sampled), dim = -1) if all_rows_have_eos_id(output, self.eos_id): if all_rows_have_eos_id(sample_semantic_ids, self.eos_id): break last_logit_indices += 1 output = mask_out_after_eos_id(output, self.pad_id, include_eos = include_eos_in_output) sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, include_eos = include_eos_in_output) # ensure all sequences have eos has_eos_mask = (output == self.eos_id).any(dim = -1) has_eos_mask = (sample_semantic_ids == self.eos_id).any(dim = -1) if not has_eos_mask.all(): append_eos_or_pad = torch.where( Loading @@ -439,9 +441,9 @@ class SemanticTransformer(nn.Module): torch.full((batch, 1), self.eos_id, dtype = torch.long, device = device), ) output = torch.cat((output, append_eos_or_pad), dim = -1) sample_semantic_ids = torch.cat((sample_semantic_ids, append_eos_or_pad), dim = -1) return output return sample_semantic_ids def forward_with_cond_scale( self, Loading @@ -463,7 +465,7 @@ class SemanticTransformer(nn.Module): raw_wave = None, ids = None, return_loss = False, text = None, text: Optional[List[str]] = None, text_embeds = None, cond_drop_prob = None ): Loading Loading @@ -577,13 +579,19 @@ class CoarseTransformer(nn.Module): cond_scale = 3, **kwargs ): logits = self.forward(*args, cond_drop_prob = 0., **kwargs) semantic_logits, coarse_logits = self.forward(*args, cond_drop_prob = 0., **kwargs) if cond_scale == 1 or not self.has_condition: return logits return semantic_logits, coarse_logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale null_semantic_logits, null_coarse_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) scaled_semantic_logits = None if exists(null_semantic_logits): scaled_semantic_logits = null_semantic_logits + (semantic_logits - null_semantic_logits) * cond_scale scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale return scaled_semantic_logits, scaled_coarse_logits def forward( self, Loading @@ -591,7 +599,7 @@ class CoarseTransformer(nn.Module): semantic_token_ids, coarse_token_ids, self_attn_mask = None, text = None, text: Optional[List[str]] = None, text_embeds = None, cond_drop_prob = None, return_only_coarse_logits = False Loading Loading @@ -719,22 +727,29 @@ class FineTransformer(nn.Module): cond_scale = 3, **kwargs ): logits = self.forward(*args, cond_drop_prob = 0., **kwargs) coarse_logits, fine_logits = self.forward(*args, cond_drop_prob = 0., **kwargs) if cond_scale == 1 or not self.has_condition: return logits return coarse_logits, fine_logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale null_coarse_logits, null_fine_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) scaled_coarse_logits = None if exists(null_coarse_logits): scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale scaled_fine_logits = null_fine_logits + (fine_logits - null_fine_logits) * cond_scale return scaled_coarse_logits, scaled_fine_logits def forward( self, coarse_token_ids, fine_token_ids, text = None, text: Optional[List[str]] = None, text_embeds = None, cond_drop_prob = None, self_attn_mask = None self_attn_mask = None, return_only_fine_logits = False ): b, device = coarse_token_ids.shape[0], coarse_token_ids.device has_text = exists(text) or exists(text_embeds) Loading Loading @@ -794,6 +809,9 @@ class FineTransformer(nn.Module): pred_coarse_tokens = rearrange(pred_coarse_tokens, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers) coarse_logits = None if not return_only_fine_logits: coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens) coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c') Loading Loading @@ -858,12 +876,15 @@ class CoarseTransformerWrapper(nn.Module): self, *, semantic_token_ids, text: Optional[List[str]] = None, text_embeds = None, max_time_steps = 512, cond_scale = 3., filter_thres = 0.9, temperature = 1., reshape_output = True, reconstruct_wave = False reconstruct_wave = False, **kwargs ): batch, device = semantic_token_ids.shape[0], self.device Loading @@ -871,18 +892,31 @@ class CoarseTransformerWrapper(nn.Module): coarse_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long) # derive text embeddings if needed has_text = exists(text) or exists(text_embeds) assert not (self.transformer.has_condition ^ has_text) if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.transformer.embed_text(text, output_device = device) # initialize init_coarse_time_step = coarse_token_ids.shape[-1] output = coarse_token_ids.clone() sampled_coarse_token_ids = coarse_token_ids.clone() for time_step in range(init_coarse_time_step, max_time_steps): for time_step in tqdm(range(init_coarse_time_step, max_time_steps), desc = 'generating coarse'): for ind in range(self.num_coarse_quantizers): is_last_step = ind == (self.num_coarse_quantizers - 1) _, coarse_logits = self.transformer.forward_with_cond_scale( coarse_token_ids = coarse_token_ids, semantic_token_ids = semantic_token_ids, text_embeds = text_embeds, cond_scale = cond_scale, return_only_coarse_logits = True return_only_coarse_logits = True, **kwargs ) last_coarse_logits = coarse_logits[:, -1] Loading @@ -894,20 +928,20 @@ class CoarseTransformerWrapper(nn.Module): sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') output = torch.cat((output, sampled), dim = -1) sampled_coarse_token_ids = torch.cat((sampled_coarse_token_ids, sampled), dim = -1) output = mask_out_after_eos_id(output, self.eos_id, include_eos = False) sampled_coarse_token_ids = mask_out_after_eos_id(sampled_coarse_token_ids, self.eos_id, include_eos = False) if reshape_output or reconstruct_wave: output = rearrange(output, 'b (n q) -> b n q', q = self.num_coarse_quantizers) sampled_coarse_token_ids = rearrange(sampled_coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers) if reconstruct_wave: assert exists(self.soundstream) wav = self.soundstream.decode_from_codebook_indices(output) wav = self.soundstream.decode_from_codebook_indices(sampled_coarse_token_ids) wav = rearrange(wav, 'b 1 n -> b n') return wav return output return sampled_coarse_token_ids def forward( self, Loading Loading @@ -998,11 +1032,98 @@ class FineTransformerWrapper(nn.Module): self.soundstream = soundstream self.transformer = transformer self.num_fine_quantizers = transformer.num_fine_quantizers self.num_coarse_quantizers = transformer.num_coarse_quantizers self.eos_id = transformer.eos_id assert self.num_coarse_quantizers > 0 self.pad_id = pad_id @property def device(self): return next(self.parameters()).device @eval_decorator @torch.no_grad() def generate( self, *, coarse_token_ids, text: Optional[List[str]] = None, text_embeds = None, cond_scale = 3., filter_thres = 0.9, temperature = 1., reshape_output = True, reconstruct_wave = False, **kwargs ): coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)') batch, device = coarse_token_ids.shape[0], self.device coarse_token_ids = coarse_token_ids.to(device) # derive text embeddings if needed has_text = exists(text) or exists(text_embeds) assert not (self.transformer.has_condition ^ has_text) if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.transformer.embed_text(text, output_device = device) # initialize fine_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long) init_fine_time_step = fine_token_ids.shape[-1] max_time_steps = coarse_token_ids.shape[1] // self.num_coarse_quantizers sampled_fine_token_ids = fine_token_ids.clone() for time_step in tqdm(range(init_fine_time_step, max_time_steps), desc = 'generating fine'): for ind in range(self.num_fine_quantizers): is_last_step = ind == (self.num_fine_quantizers - 1) _, fine_logits = self.transformer.forward_with_cond_scale( coarse_token_ids = coarse_token_ids, fine_token_ids = fine_token_ids, text_embeds = text_embeds, cond_scale = cond_scale, return_only_fine_logits = True, **kwargs ) last_fine_logits = fine_logits[:, -1] if not is_last_step: last_fine_logits[:, -1] = float('-inf') # prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval filtered_logits = top_k(last_fine_logits, thres = filter_thres) sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') sampled_fine_token_ids = torch.cat((sampled_fine_token_ids, sampled), dim = -1) sampled_fine_token_ids = mask_out_after_eos_id(sampled_fine_token_ids, self.eos_id, include_eos = False) if reshape_output or reconstruct_wave: sampled_fine_token_ids = rearrange(sampled_fine_token_ids, 'b (n q) -> b n q', q = self.num_fine_quantizers) if reconstruct_wave: assert exists(self.soundstream) coarse_token_ids = rearrange(coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers) coarse_and_fine_ids = torch.cat((coarse_token_ids, sampled_fine_token_ids), dim = -1) wav = self.soundstream.decode_from_codebook_indices(coarse_and_fine_ids) wav = rearrange(wav, 'b 1 n -> b n') return wav return sampled_fine_token_ids def forward( self, *, Loading Loading @@ -1072,16 +1193,60 @@ class AudioLM(nn.Module): def __init__( self, *, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], soundstream: SoundStream, semantic_transformer: SemanticTransformer, coarse_transformer: CoarseTransformer, fine_transformer: FineTransformer, fine_transformer: FineTransformer ): super().__init__() self.soundstream = soundstream self.semantic = semantic_transformer self.coarse = coarse_transformer self.fine = fine_transformer def forward(self, x): raise NotImplemented self.coarse = CoarseTransformerWrapper( wav2vec = wav2vec, soundstream = soundstream, transformer = coarse_transformer, unique_consecutive = semantic_transformer.unique_consecutive ) self.fine = FineTransformerWrapper( soundstream = soundstream, transformer = fine_transformer ) @property def device(self): return next(self.parameters()).device @eval_decorator @torch.no_grad() def forward( self, *, batch_size = 1, text: Optional[List[str]] = None, prime_wave = None, max_length = 2048 ): if exists(prime_wave): prime_wave = prime_wave.to(self.device) semantic_token_ids = self.semantic.generate( text = text, batch_size = batch_size, prime_wave = prime_wave, max_length = max_length ) coarse_token_ids = self.coarse.generate( text = text, semantic_token_ids = semantic_token_ids ) generated_wave = self.fine.generate( text = text, coarse_token_ids = coarse_token_ids, reconstruct_wave = True ) return generated_wave setup.py +2 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.41', version = '0.0.42', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -26,6 +26,7 @@ setup( 'torch>=1.6', 'torchaudio', 'transformers', 'tqdm', 'typeguard', 'vector-quantize-pytorch>=0.10.11' ], Loading Loading
README.md +24 −1 Original line number Diff line number Diff line Loading @@ -155,7 +155,29 @@ loss = train_wrapper( loss.backward() ``` - [ ] show how to generate from prompt tensor or file All together now ```python audiolm = AudioLM( wav2vec = wav2vec, soundstream = soundstream, semantic_transformer = semantic_transformer, coarse_transformer = coarse_transformer, fine_transformer = transformer ) generated_wav = audiolm(batch_size = 1) # or with priming generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8)) # or with text condition, if given generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells']) ``` ## Appreciation Loading Loading @@ -192,6 +214,7 @@ loss.backward() - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] simplify training even more within AudioLM class - [ ] cli tool, something like `audiolm generate <wav.file | text>` and save generated wav file to local directory ## Citations Loading
audiolm_pytorch/audiolm_pytorch.py +206 −41 Original line number Diff line number Diff line import math from functools import partial from typing import Optional, Union from typing import Optional, Union, List from typeguard import typechecked import torch Loading @@ -21,6 +21,8 @@ from torchaudio.functional import resample from audiolm_pytorch.soundstream import SoundStream from tqdm import tqdm # helper functions def exists(val): Loading Loading @@ -356,7 +358,7 @@ class SemanticTransformer(nn.Module): self, *, max_length, text = None, text: Optional[List[str]] = None, text_embeds = None, prime_wave = None, prime_ids = None, Loading Loading @@ -396,17 +398,17 @@ class SemanticTransformer(nn.Module): batch = ids.shape[0] start_length = ids.shape[-1] output = ids.clone() sample_semantic_ids = ids.clone() batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1') last_logit_indices = (ids != self.pad_id).sum(dim = -1).long() # sample from transformer for ind in range(start_length, max_length): for ind in tqdm(range(start_length, max_length), desc = 'generating semantic'): logits = self.forward_with_cond_scale( ids = output, ids = sample_semantic_ids, text_embeds = text_embeds, **kwargs ) Loading @@ -419,18 +421,18 @@ class SemanticTransformer(nn.Module): sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') output = torch.cat((output, sampled), dim = -1) sample_semantic_ids = torch.cat((sample_semantic_ids, sampled), dim = -1) if all_rows_have_eos_id(output, self.eos_id): if all_rows_have_eos_id(sample_semantic_ids, self.eos_id): break last_logit_indices += 1 output = mask_out_after_eos_id(output, self.pad_id, include_eos = include_eos_in_output) sample_semantic_ids = mask_out_after_eos_id(sample_semantic_ids, self.pad_id, include_eos = include_eos_in_output) # ensure all sequences have eos has_eos_mask = (output == self.eos_id).any(dim = -1) has_eos_mask = (sample_semantic_ids == self.eos_id).any(dim = -1) if not has_eos_mask.all(): append_eos_or_pad = torch.where( Loading @@ -439,9 +441,9 @@ class SemanticTransformer(nn.Module): torch.full((batch, 1), self.eos_id, dtype = torch.long, device = device), ) output = torch.cat((output, append_eos_or_pad), dim = -1) sample_semantic_ids = torch.cat((sample_semantic_ids, append_eos_or_pad), dim = -1) return output return sample_semantic_ids def forward_with_cond_scale( self, Loading @@ -463,7 +465,7 @@ class SemanticTransformer(nn.Module): raw_wave = None, ids = None, return_loss = False, text = None, text: Optional[List[str]] = None, text_embeds = None, cond_drop_prob = None ): Loading Loading @@ -577,13 +579,19 @@ class CoarseTransformer(nn.Module): cond_scale = 3, **kwargs ): logits = self.forward(*args, cond_drop_prob = 0., **kwargs) semantic_logits, coarse_logits = self.forward(*args, cond_drop_prob = 0., **kwargs) if cond_scale == 1 or not self.has_condition: return logits return semantic_logits, coarse_logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale null_semantic_logits, null_coarse_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) scaled_semantic_logits = None if exists(null_semantic_logits): scaled_semantic_logits = null_semantic_logits + (semantic_logits - null_semantic_logits) * cond_scale scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale return scaled_semantic_logits, scaled_coarse_logits def forward( self, Loading @@ -591,7 +599,7 @@ class CoarseTransformer(nn.Module): semantic_token_ids, coarse_token_ids, self_attn_mask = None, text = None, text: Optional[List[str]] = None, text_embeds = None, cond_drop_prob = None, return_only_coarse_logits = False Loading Loading @@ -719,22 +727,29 @@ class FineTransformer(nn.Module): cond_scale = 3, **kwargs ): logits = self.forward(*args, cond_drop_prob = 0., **kwargs) coarse_logits, fine_logits = self.forward(*args, cond_drop_prob = 0., **kwargs) if cond_scale == 1 or not self.has_condition: return logits return coarse_logits, fine_logits null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) return null_logits + (logits - null_logits) * cond_scale null_coarse_logits, null_fine_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) scaled_coarse_logits = None if exists(null_coarse_logits): scaled_coarse_logits = null_coarse_logits + (coarse_logits - null_coarse_logits) * cond_scale scaled_fine_logits = null_fine_logits + (fine_logits - null_fine_logits) * cond_scale return scaled_coarse_logits, scaled_fine_logits def forward( self, coarse_token_ids, fine_token_ids, text = None, text: Optional[List[str]] = None, text_embeds = None, cond_drop_prob = None, self_attn_mask = None self_attn_mask = None, return_only_fine_logits = False ): b, device = coarse_token_ids.shape[0], coarse_token_ids.device has_text = exists(text) or exists(text_embeds) Loading Loading @@ -794,6 +809,9 @@ class FineTransformer(nn.Module): pred_coarse_tokens = rearrange(pred_coarse_tokens, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers) coarse_logits = None if not return_only_fine_logits: coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens) coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c') Loading Loading @@ -858,12 +876,15 @@ class CoarseTransformerWrapper(nn.Module): self, *, semantic_token_ids, text: Optional[List[str]] = None, text_embeds = None, max_time_steps = 512, cond_scale = 3., filter_thres = 0.9, temperature = 1., reshape_output = True, reconstruct_wave = False reconstruct_wave = False, **kwargs ): batch, device = semantic_token_ids.shape[0], self.device Loading @@ -871,18 +892,31 @@ class CoarseTransformerWrapper(nn.Module): coarse_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long) # derive text embeddings if needed has_text = exists(text) or exists(text_embeds) assert not (self.transformer.has_condition ^ has_text) if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.transformer.embed_text(text, output_device = device) # initialize init_coarse_time_step = coarse_token_ids.shape[-1] output = coarse_token_ids.clone() sampled_coarse_token_ids = coarse_token_ids.clone() for time_step in range(init_coarse_time_step, max_time_steps): for time_step in tqdm(range(init_coarse_time_step, max_time_steps), desc = 'generating coarse'): for ind in range(self.num_coarse_quantizers): is_last_step = ind == (self.num_coarse_quantizers - 1) _, coarse_logits = self.transformer.forward_with_cond_scale( coarse_token_ids = coarse_token_ids, semantic_token_ids = semantic_token_ids, text_embeds = text_embeds, cond_scale = cond_scale, return_only_coarse_logits = True return_only_coarse_logits = True, **kwargs ) last_coarse_logits = coarse_logits[:, -1] Loading @@ -894,20 +928,20 @@ class CoarseTransformerWrapper(nn.Module): sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') output = torch.cat((output, sampled), dim = -1) sampled_coarse_token_ids = torch.cat((sampled_coarse_token_ids, sampled), dim = -1) output = mask_out_after_eos_id(output, self.eos_id, include_eos = False) sampled_coarse_token_ids = mask_out_after_eos_id(sampled_coarse_token_ids, self.eos_id, include_eos = False) if reshape_output or reconstruct_wave: output = rearrange(output, 'b (n q) -> b n q', q = self.num_coarse_quantizers) sampled_coarse_token_ids = rearrange(sampled_coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers) if reconstruct_wave: assert exists(self.soundstream) wav = self.soundstream.decode_from_codebook_indices(output) wav = self.soundstream.decode_from_codebook_indices(sampled_coarse_token_ids) wav = rearrange(wav, 'b 1 n -> b n') return wav return output return sampled_coarse_token_ids def forward( self, Loading Loading @@ -998,11 +1032,98 @@ class FineTransformerWrapper(nn.Module): self.soundstream = soundstream self.transformer = transformer self.num_fine_quantizers = transformer.num_fine_quantizers self.num_coarse_quantizers = transformer.num_coarse_quantizers self.eos_id = transformer.eos_id assert self.num_coarse_quantizers > 0 self.pad_id = pad_id @property def device(self): return next(self.parameters()).device @eval_decorator @torch.no_grad() def generate( self, *, coarse_token_ids, text: Optional[List[str]] = None, text_embeds = None, cond_scale = 3., filter_thres = 0.9, temperature = 1., reshape_output = True, reconstruct_wave = False, **kwargs ): coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)') batch, device = coarse_token_ids.shape[0], self.device coarse_token_ids = coarse_token_ids.to(device) # derive text embeddings if needed has_text = exists(text) or exists(text_embeds) assert not (self.transformer.has_condition ^ has_text) if not exists(text_embeds) and exists(text): with torch.no_grad(): text_embeds = self.transformer.embed_text(text, output_device = device) # initialize fine_token_ids = torch.empty((batch, 0), device = device, dtype = torch.long) init_fine_time_step = fine_token_ids.shape[-1] max_time_steps = coarse_token_ids.shape[1] // self.num_coarse_quantizers sampled_fine_token_ids = fine_token_ids.clone() for time_step in tqdm(range(init_fine_time_step, max_time_steps), desc = 'generating fine'): for ind in range(self.num_fine_quantizers): is_last_step = ind == (self.num_fine_quantizers - 1) _, fine_logits = self.transformer.forward_with_cond_scale( coarse_token_ids = coarse_token_ids, fine_token_ids = fine_token_ids, text_embeds = text_embeds, cond_scale = cond_scale, return_only_fine_logits = True, **kwargs ) last_fine_logits = fine_logits[:, -1] if not is_last_step: last_fine_logits[:, -1] = float('-inf') # prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval filtered_logits = top_k(last_fine_logits, thres = filter_thres) sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) sampled = rearrange(sampled, 'b -> b 1') sampled_fine_token_ids = torch.cat((sampled_fine_token_ids, sampled), dim = -1) sampled_fine_token_ids = mask_out_after_eos_id(sampled_fine_token_ids, self.eos_id, include_eos = False) if reshape_output or reconstruct_wave: sampled_fine_token_ids = rearrange(sampled_fine_token_ids, 'b (n q) -> b n q', q = self.num_fine_quantizers) if reconstruct_wave: assert exists(self.soundstream) coarse_token_ids = rearrange(coarse_token_ids, 'b (n q) -> b n q', q = self.num_coarse_quantizers) coarse_and_fine_ids = torch.cat((coarse_token_ids, sampled_fine_token_ids), dim = -1) wav = self.soundstream.decode_from_codebook_indices(coarse_and_fine_ids) wav = rearrange(wav, 'b 1 n -> b n') return wav return sampled_fine_token_ids def forward( self, *, Loading Loading @@ -1072,16 +1193,60 @@ class AudioLM(nn.Module): def __init__( self, *, wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], soundstream: SoundStream, semantic_transformer: SemanticTransformer, coarse_transformer: CoarseTransformer, fine_transformer: FineTransformer, fine_transformer: FineTransformer ): super().__init__() self.soundstream = soundstream self.semantic = semantic_transformer self.coarse = coarse_transformer self.fine = fine_transformer def forward(self, x): raise NotImplemented self.coarse = CoarseTransformerWrapper( wav2vec = wav2vec, soundstream = soundstream, transformer = coarse_transformer, unique_consecutive = semantic_transformer.unique_consecutive ) self.fine = FineTransformerWrapper( soundstream = soundstream, transformer = fine_transformer ) @property def device(self): return next(self.parameters()).device @eval_decorator @torch.no_grad() def forward( self, *, batch_size = 1, text: Optional[List[str]] = None, prime_wave = None, max_length = 2048 ): if exists(prime_wave): prime_wave = prime_wave.to(self.device) semantic_token_ids = self.semantic.generate( text = text, batch_size = batch_size, prime_wave = prime_wave, max_length = max_length ) coarse_token_ids = self.coarse.generate( text = text, semantic_token_ids = semantic_token_ids ) generated_wave = self.fine.generate( text = text, coarse_token_ids = coarse_token_ids, reconstruct_wave = True ) return generated_wave
setup.py +2 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.41', version = '0.0.42', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading @@ -26,6 +26,7 @@ setup( 'torch>=1.6', 'torchaudio', 'transformers', 'tqdm', 'typeguard', 'vector-quantize-pytorch>=0.10.11' ], Loading