Loading README.md +1 −1 Original line number Diff line number Diff line Loading @@ -168,6 +168,7 @@ loss.backward() - [x] accommodate variable lengthed audio, bring in eos token - [x] make sure unique consecutive works with coarse transformer - [x] pretty printing all discriminator losses to log - [x] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing - [ ] complete full training code for soundstream, taking care of discriminator training - [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework Loading @@ -181,7 +182,6 @@ loss.backward() - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] simplify training even more within AudioLM class - [ ] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing ## Citations Loading audiolm_pytorch/audiolm_pytorch.py +11 −2 Original line number Diff line number Diff line Loading @@ -390,9 +390,13 @@ class SemanticTransformer(nn.Module): # start length and get running id output batch = ids.shape[0] start_length = ids.shape[-1] output = ids.clone() batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1') last_logit_indices = (ids != self.pad_id).sum(dim = -1).long() # sample from transformer for ind in range(start_length, max_length): Loading @@ -403,7 +407,10 @@ class SemanticTransformer(nn.Module): **kwargs ) last_logits = logits[:, -1] last_logits = logits[batch_range, last_logit_indices] last_logits = rearrange(last_logits, 'b 1 c -> b c') filtered_logits = top_k(last_logits, thres = filter_thres) sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) Loading @@ -413,7 +420,9 @@ class SemanticTransformer(nn.Module): if all_rows_have_eos_id(output, self.eos_id): break output = mask_out_after_eos_id(output, self.pad_id) last_logit_indices += 1 output = mask_out_after_eos_id(output, self.pad_id, include_eos = False) return output def forward_with_cond_scale( Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.30', version = '0.0.31', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
README.md +1 −1 Original line number Diff line number Diff line Loading @@ -168,6 +168,7 @@ loss.backward() - [x] accommodate variable lengthed audio, bring in eos token - [x] make sure unique consecutive works with coarse transformer - [x] pretty printing all discriminator losses to log - [x] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing - [ ] complete full training code for soundstream, taking care of discriminator training - [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework Loading @@ -181,7 +182,6 @@ loss.backward() - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] simplify training even more within AudioLM class - [ ] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing ## Citations Loading
audiolm_pytorch/audiolm_pytorch.py +11 −2 Original line number Diff line number Diff line Loading @@ -390,9 +390,13 @@ class SemanticTransformer(nn.Module): # start length and get running id output batch = ids.shape[0] start_length = ids.shape[-1] output = ids.clone() batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1') last_logit_indices = (ids != self.pad_id).sum(dim = -1).long() # sample from transformer for ind in range(start_length, max_length): Loading @@ -403,7 +407,10 @@ class SemanticTransformer(nn.Module): **kwargs ) last_logits = logits[:, -1] last_logits = logits[batch_range, last_logit_indices] last_logits = rearrange(last_logits, 'b 1 c -> b c') filtered_logits = top_k(last_logits, thres = filter_thres) sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1) Loading @@ -413,7 +420,9 @@ class SemanticTransformer(nn.Module): if all_rows_have_eos_id(output, self.eos_id): break output = mask_out_after_eos_id(output, self.pad_id) last_logit_indices += 1 output = mask_out_after_eos_id(output, self.pad_id, include_eos = False) return output def forward_with_cond_scale( Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.30', version = '0.0.31', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading