# NOTE(review): this span was a diff view (README.md, audiolm_pytorch.py,
# setup.py) collapsed onto a few physical lines. The Python hunks are restored
# as formatted code below; the README and setup.py hunks are not Python and are
# kept as comments. Hunk headers are kept as comments, and anything the view
# elided is marked "[elided]" rather than guessed at. Block nesting was
# inferred from syntax because the original indentation was lost - confirm
# against the real file.

# ---------------------------------------------------------------------------
# README.md  (+5 −0)
# ---------------------------------------------------------------------------
# @@ -63,6 +63,10 @@ loss = semantic_transformer(
#     )
#
#     loss.backward()
#
#     # after much training above
#
#     sample = semantic_transformer.generate(max_length = 128) # (1, < 128) - may terminate early if it detects [eos]
#     ```
#
#     ex. `CoarseTransformer`
#     [elided]
# @@ -177,6 +181,7 @@ loss.backward()
#     - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
#     - [ ] add option to use flash attention
#     - [ ] simplify training even more within AudioLM class
#     - [ ] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing
#
#     ## Citations
#     [elided]

# ---------------------------------------------------------------------------
# audiolm_pytorch/audiolm_pytorch.py  (+123 −2)
# ---------------------------------------------------------------------------

from functools import wraps

# @@ -36,6 +36,15 @@ def remainder_needed_until_multiple(n, mult):

def round_down_nearest_multiple(val, mult):
    """Round `val` down to a multiple of `mult` using floor division."""
    return (val // mult) * mult

def eval_decorator(fn):
    """Wrap a model method so that it runs with the model in eval mode.

    The wrapper saves `model.training`, calls `model.eval()`, invokes `fn`,
    and finally calls `model.train(<saved flag>)` so the caller gets the model
    back in whatever mode it was in before the call.

    Args:
        fn: callable whose first positional argument is the model.

    Returns:
        The wrapped callable, carrying `fn`'s name and docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore the previous mode even when `fn` raises, otherwise a
            # failed call would leave the model stuck in eval mode
            model.train(was_training)
    return inner

# attention related utils

# def grad_shrink(t, alpha = 0.1):
#     [elided]

# @@ -334,6 +343,81 @@ class SemanticTransformer(nn.Module):

class SemanticTransformer(nn.Module):
    # Context lines shown by the hunk (the enclosing definition is elided):
    #
    #     self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
    #     self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    @property
    def device(self):
        """Device of the first parameter returned by `self.parameters()`."""
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        *,
        max_length,
        prime_wave = None,
        prime_ids = None,
        batch_size = 1,
        cond_scale = 3,
        filter_thres = 0.9,
        temperature = 1.,
        **kwargs
    ):
        """Sample token ids one position at a time until `max_length` or eos.

        Runs in eval mode with gradients disabled (see the decorators).

        Args:
            max_length: upper bound of the sampling loop, which iterates over
                `range(start_length, max_length)`.
            prime_wave: optional raw wave; converted to ids with
                `self.wav2vec(prime_wave, flatten = False)`. Mutually
                exclusive with `prime_ids`.
            prime_ids: optional ids used directly as the starting sequence.
            batch_size: number of rows of the empty starting sequence; only
                used when neither `prime_wave` nor `prime_ids` is given.
            cond_scale: forwarded to `forward_with_cond_scale`.
            filter_thres: passed to `top_k` as `thres`.
            temperature: passed to `gumbel_sample`.
            **kwargs: forwarded to `forward_with_cond_scale`.

        Returns:
            The running id tensor after `mask_out_after_eos_id` is applied.
        """
        device = self.device

        # derive wav2vec ids from the input wave

        if exists(prime_wave):
            assert not exists(prime_ids)
            assert exists(self.wav2vec)
            ids = self.wav2vec(prime_wave, flatten = False)
        elif exists(prime_ids):
            ids = prime_ids
        else:
            ids = torch.empty((batch_size, 0), dtype = torch.long, device = device)

        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # start length and get running id output

        start_length = ids.shape[-1]
        output = ids.clone()

        # sample from transformer

        for _ in range(start_length, max_length):
            # `cond_scale` has to be passed explicitly: it is a named
            # parameter of `generate`, so it is never part of `**kwargs`, and
            # without this the caller's value would be silently ignored
            logits = self.forward_with_cond_scale(
                ids = output,
                cond_scale = cond_scale,
                **kwargs
            )

            last_logits = logits[:, -1]
            filtered_logits = top_k(last_logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            output = torch.cat((output, sampled), dim = -1)

            if all_rows_have_eos_id(output, self.eos_id):
                break

        output = mask_out_after_eos_id(output, self.eos_id)
        return output

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Blend conditioned and unconditioned logits by `cond_scale`.

        Calls `forward` with `cond_drop_prob = 0.`. When `cond_scale == 1` or
        `self.has_condition` is falsy, that result is returned directly.
        Otherwise `forward` is called again with `cond_drop_prob = 1.` and the
        result is `null_logits + (logits - null_logits) * cond_scale`.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # `forward` is only partially visible in this view; the raw hunk text is
    # kept here instead of a reconstruction. The two consecutive `device = `
    # lines are presumably the removed and added versions of one line (the
    # view stripped the +/- markers) - confirm against the real diff.
    #
    #     def forward(
    #         self,
    #         *,
    #         [elided]
    # @@ -344,7 +428,7 @@ class SemanticTransformer(nn.Module):
    #         text_embeds = None,
    #         cond_drop_prob = None
    #     ):
    #         device = next(self.parameters()).device
    #         device = self.device
    #
    #         assert exists(raw_wave) ^ exists(ids)
    #         [elided]
    # @@ -354,6 +438,7 @@ class SemanticTransformer(nn.Module):
    #         b = ids.shape[0]
    #
    #         if self.training:
    #             ids = append_eos_id(ids, self.eos_id)
    #
    #         if self.unique_consecutive:
    #         [elided]

# @@ -441,6 +526,24 @@ class CoarseTransformer(nn.Module):

class CoarseTransformer(nn.Module):
    # Context lines shown by the hunk (the enclosing definition is elided):
    #
    #     self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1)
    #     self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))

    @property
    def device(self):
        """Device of the first parameter returned by `self.parameters()`."""
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Blend conditioned and unconditioned logits by `cond_scale`.

        Calls `forward` with `cond_drop_prob = 0.`. When `cond_scale == 1` or
        `self.has_condition` is falsy, that result is returned directly.
        Otherwise `forward` is called again with `cond_drop_prob = 1.` and the
        result is `null_logits + (logits - null_logits) * cond_scale`.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # `forward` is cut off by the view after its first lines:
    #
    #     def forward(
    #         self,
    #         *,
    #         [elided]

# @@ -553,6 +656,24 @@ class FineTransformer(nn.Module):

class FineTransformer(nn.Module):
    # Context lines shown by the hunk (the enclosing definition is elided):
    #
    #     self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))
    #     self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim))

    @property
    def device(self):
        """Device of the first parameter returned by `self.parameters()`."""
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Blend conditioned and unconditioned logits by `cond_scale`.

        Calls `forward` with `cond_drop_prob = 0.`. When `cond_scale == 1` or
        `self.has_condition` is falsy, that result is returned directly.
        Otherwise `forward` is called again with `cond_drop_prob = 1.` and the
        result is `null_logits + (logits - null_logits) * cond_scale`.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # `forward` is cut off by the view after its first lines:
    #
    #     def forward(
    #         self,
    #         coarse_token_ids,
    #         [elided]

# ---------------------------------------------------------------------------
# setup.py  (+1 −1)
# ---------------------------------------------------------------------------
# The two `version = ` lines are presumably the removed and added versions of
# one line (the view stripped the +/- markers) - confirm against the real diff.
#
# @@ -3,7 +3,7 @@ from setuptools import setup, find_packages
#     setup(
#       name = 'audiolm-pytorch',
#       packages = find_packages(exclude=[]),
#       version = '0.0.28',
#       version = '0.0.29',
#       license='MIT',
#       description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
#       author = 'Phil Wang',
#     [elided]
README.md +5 −0 Original line number Diff line number Diff line Loading @@ -63,6 +63,10 @@ loss = semantic_transformer( ) loss.backward() # after much training above sample = semantic_transformer.generate(max_length = 128) # (1, < 128) - may terminate early if it detects [eos] ``` ex. `CoarseTransformer` Loading Loading @@ -177,6 +181,7 @@ loss.backward() - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] simplify training even more within AudioLM class - [ ] when generating semantic tokens, handle the case where the last logits are not necessarily at the last position in the sequence, given unique consecutive processing ## Citations Loading
# audiolm_pytorch/audiolm_pytorch.py  (+123 −2)
#
# NOTE(review): this span was a diff view collapsed onto a few physical lines.
# It is restored here as formatted code. Hunk headers are kept as comments,
# and anything the view elided is marked "[elided]" rather than guessed at.
# Block nesting was inferred from syntax because the original indentation was
# lost - confirm against the real file.

from functools import wraps

# @@ -36,6 +36,15 @@ def remainder_needed_until_multiple(n, mult):

def round_down_nearest_multiple(val, mult):
    """Round `val` down to a multiple of `mult` using floor division."""
    return (val // mult) * mult

def eval_decorator(fn):
    """Wrap a model method so that it runs with the model in eval mode.

    The wrapper saves `model.training`, calls `model.eval()`, invokes `fn`,
    and finally calls `model.train(<saved flag>)` so the caller gets the model
    back in whatever mode it was in before the call.

    Args:
        fn: callable whose first positional argument is the model.

    Returns:
        The wrapped callable, carrying `fn`'s name and docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # restore the previous mode even when `fn` raises, otherwise a
            # failed call would leave the model stuck in eval mode
            model.train(was_training)
    return inner

# attention related utils

# def grad_shrink(t, alpha = 0.1):
#     [elided]

# @@ -334,6 +343,81 @@ class SemanticTransformer(nn.Module):

class SemanticTransformer(nn.Module):
    # Context lines shown by the hunk (the enclosing definition is elided):
    #
    #     self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
    #     self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    @property
    def device(self):
        """Device of the first parameter returned by `self.parameters()`."""
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        *,
        max_length,
        prime_wave = None,
        prime_ids = None,
        batch_size = 1,
        cond_scale = 3,
        filter_thres = 0.9,
        temperature = 1.,
        **kwargs
    ):
        """Sample token ids one position at a time until `max_length` or eos.

        Runs in eval mode with gradients disabled (see the decorators).

        Args:
            max_length: upper bound of the sampling loop, which iterates over
                `range(start_length, max_length)`.
            prime_wave: optional raw wave; converted to ids with
                `self.wav2vec(prime_wave, flatten = False)`. Mutually
                exclusive with `prime_ids`.
            prime_ids: optional ids used directly as the starting sequence.
            batch_size: number of rows of the empty starting sequence; only
                used when neither `prime_wave` nor `prime_ids` is given.
            cond_scale: forwarded to `forward_with_cond_scale`.
            filter_thres: passed to `top_k` as `thres`.
            temperature: passed to `gumbel_sample`.
            **kwargs: forwarded to `forward_with_cond_scale`.

        Returns:
            The running id tensor after `mask_out_after_eos_id` is applied.
        """
        device = self.device

        # derive wav2vec ids from the input wave

        if exists(prime_wave):
            assert not exists(prime_ids)
            assert exists(self.wav2vec)
            ids = self.wav2vec(prime_wave, flatten = False)
        elif exists(prime_ids):
            ids = prime_ids
        else:
            ids = torch.empty((batch_size, 0), dtype = torch.long, device = device)

        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # start length and get running id output

        start_length = ids.shape[-1]
        output = ids.clone()

        # sample from transformer

        for _ in range(start_length, max_length):
            # `cond_scale` has to be passed explicitly: it is a named
            # parameter of `generate`, so it is never part of `**kwargs`, and
            # without this the caller's value would be silently ignored
            logits = self.forward_with_cond_scale(
                ids = output,
                cond_scale = cond_scale,
                **kwargs
            )

            last_logits = logits[:, -1]
            filtered_logits = top_k(last_logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            output = torch.cat((output, sampled), dim = -1)

            if all_rows_have_eos_id(output, self.eos_id):
                break

        output = mask_out_after_eos_id(output, self.eos_id)
        return output

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Blend conditioned and unconditioned logits by `cond_scale`.

        Calls `forward` with `cond_drop_prob = 0.`. When `cond_scale == 1` or
        `self.has_condition` is falsy, that result is returned directly.
        Otherwise `forward` is called again with `cond_drop_prob = 1.` and the
        result is `null_logits + (logits - null_logits) * cond_scale`.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # `forward` is only partially visible in this view; the raw hunk text is
    # kept here instead of a reconstruction. The two consecutive `device = `
    # lines are presumably the removed and added versions of one line (the
    # view stripped the +/- markers) - confirm against the real diff.
    #
    #     def forward(
    #         self,
    #         *,
    #         [elided]
    # @@ -344,7 +428,7 @@ class SemanticTransformer(nn.Module):
    #         text_embeds = None,
    #         cond_drop_prob = None
    #     ):
    #         device = next(self.parameters()).device
    #         device = self.device
    #
    #         assert exists(raw_wave) ^ exists(ids)
    #         [elided]
    # @@ -354,6 +438,7 @@ class SemanticTransformer(nn.Module):
    #         b = ids.shape[0]
    #
    #         if self.training:
    #             ids = append_eos_id(ids, self.eos_id)
    #
    #         if self.unique_consecutive:
    #         [elided]

# @@ -441,6 +526,24 @@ class CoarseTransformer(nn.Module):

class CoarseTransformer(nn.Module):
    # Context lines shown by the hunk (the enclosing definition is elided):
    #
    #     self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1)
    #     self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))

    @property
    def device(self):
        """Device of the first parameter returned by `self.parameters()`."""
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Blend conditioned and unconditioned logits by `cond_scale`.

        Calls `forward` with `cond_drop_prob = 0.`. When `cond_scale == 1` or
        `self.has_condition` is falsy, that result is returned directly.
        Otherwise `forward` is called again with `cond_drop_prob = 1.` and the
        result is `null_logits + (logits - null_logits) * cond_scale`.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # `forward` is cut off by the view after its first lines:
    #
    #     def forward(
    #         self,
    #         *,
    #         [elided]

# @@ -553,6 +656,24 @@ class FineTransformer(nn.Module):

class FineTransformer(nn.Module):
    # Context lines shown by the hunk (the enclosing definition is elided):
    #
    #     self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))
    #     self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim))

    @property
    def device(self):
        """Device of the first parameter returned by `self.parameters()`."""
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        """Blend conditioned and unconditioned logits by `cond_scale`.

        Calls `forward` with `cond_drop_prob = 0.`. When `cond_scale == 1` or
        `self.has_condition` is falsy, that result is returned directly.
        Otherwise `forward` is called again with `cond_drop_prob = 1.` and the
        result is `null_logits + (logits - null_logits) * cond_scale`.
        """
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    # `forward` is cut off by the view after its first lines:
    #
    #     def forward(
    #         self,
    #         coarse_token_ids,
    #         [elided]
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.28', version = '0.0.29', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading