Commit ea194173 authored by Phil Wang's avatar Phil Wang
Browse files

handle unique consecutive issue with generating semantic token ids

parent 0806f62b
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -168,6 +168,7 @@ loss.backward()
- [x] accommodate variable lengthed audio, bring in eos token
- [x] make sure unique consecutive works with coarse transformer
- [x] pretty printing all discriminator losses to log
- [x] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing

- [ ] complete full training code for soundstream, taking care of discriminator training
- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
@@ -181,7 +182,6 @@ loss.backward()
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
- [ ] add option to use flash attention
- [ ] simplify training even more within AudioLM class
- [ ] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing

## Citations

+11 −2
Original line number Diff line number Diff line
@@ -390,9 +390,13 @@ class SemanticTransformer(nn.Module):

        # start length and get running id output

        batch = ids.shape[0]
        start_length = ids.shape[-1]
        output = ids.clone()

        batch_range = rearrange(torch.arange(batch, device = device), 'b -> b 1')
        last_logit_indices = (ids != self.pad_id).sum(dim = -1).long()

        # sample from transformer

        for ind in range(start_length, max_length):
@@ -403,7 +407,10 @@ class SemanticTransformer(nn.Module):
                **kwargs
            )

            last_logits = logits[:, -1]
            last_logits = logits[batch_range, last_logit_indices]

            last_logits = rearrange(last_logits, 'b 1 c -> b c')

            filtered_logits = top_k(last_logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

@@ -413,7 +420,9 @@ class SemanticTransformer(nn.Module):
            if all_rows_have_eos_id(output, self.eos_id):
                break

        output = mask_out_after_eos_id(output, self.pad_id)
            last_logit_indices += 1

        output = mask_out_after_eos_id(output, self.pad_id, include_eos = False)
        return output

    def forward_with_cond_scale(
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.30',
  version = '0.0.31',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',