Commit d9152df6 authored by Phil Wang's avatar Phil Wang
Browse files

code for sampling semantic token ids from semantic transformer, input can be...

code for sampling semantic token ids from semantic transformer, input can be prompt token ids, raw wave, or batch size to start from the start token
parent 3ea0bcf4
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -63,6 +63,10 @@ loss = semantic_transformer(
)

loss.backward()

# after much training above

sample = semantic_transformer.generate(max_length = 128) # (1, < 128) - may terminate early if it detects [eos]
```

ex. `CoarseTransformer`
@@ -177,6 +181,7 @@ loss.backward()
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
- [ ] add option to use flash attention
- [ ] simplify training even more within AudioLM class
- [ ] handle when generating semantic tokens, that last logits may not be necessarily the last in the sequence given unique consecutive processing

## Citations

+123 −2
Original line number Diff line number Diff line
@@ -36,6 +36,15 @@ def remainder_needed_until_multiple(n, mult):
def round_down_nearest_multiple(val, mult):
    return (val // mult) * mult

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# attention related utils

def grad_shrink(t, alpha = 0.1):
@@ -334,6 +343,81 @@ class SemanticTransformer(nn.Module):
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, grad_shrink_alpha = grad_shrink_alpha, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    @property
    def device(self):
        return next(self.parameters()).device

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        *,
        max_length,
        prime_wave = None,
        prime_ids = None,
        batch_size = 1,
        cond_scale = 3,
        filter_thres = 0.9,
        temperature = 1.,
        **kwargs
    ):
        device = self.device

        # derive wav2vec ids from the input wave

        if exists(prime_wave):
            assert not exists(prime_ids)
            assert exists(self.wav2vec)
            ids = self.wav2vec(prime_wave, flatten = False)
        elif exists(prime_ids):
            ids = prime_ids
        else:
            ids = torch.empty((batch_size, 0), dtype = torch.long, device = device)

        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # start length and get running id output

        start_length = ids.shape[-1]
        output = ids.clone()

        # sample from transformer

        for ind in range(start_length, max_length):

            logits = self.forward_with_cond_scale(
                ids = output,
                **kwargs
            )

            last_logits = logits[:, -1]
            filtered_logits = top_k(last_logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

            sampled = rearrange(sampled, 'b -> b 1')
            output = torch.cat((output, sampled), dim = -1)

            if all_rows_have_eos_id(output, self.eos_id):
                break

        output = mask_out_after_eos_id(output, self.eos_id)
        return output

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        *,
@@ -344,7 +428,7 @@ class SemanticTransformer(nn.Module):
        text_embeds = None,
        cond_drop_prob = None
    ):
        device = next(self.parameters()).device
        device = self.device

        assert exists(raw_wave) ^ exists(ids)

@@ -354,6 +438,7 @@ class SemanticTransformer(nn.Module):

        b = ids.shape[0]

        if self.training:
            ids = append_eos_id(ids, self.eos_id)

        if self.unique_consecutive:
@@ -441,6 +526,24 @@ class CoarseTransformer(nn.Module):
        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1)
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        *,
@@ -553,6 +656,24 @@ class FineTransformer(nn.Module):
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim))

    @property
    def device(self):
        return next(self.parameters()).device

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 3,
        **kwargs
    ):
        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)

        if cond_scale == 1 or not self.has_condition:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

    def forward(
        self,
        coarse_token_ids,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.28',
  version = '0.0.29',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',