Commit 3ea0bcf4 authored by Phil Wang's avatar Phil Wang
Browse files

prepare some sampling functions

parent 6a0c7e73
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -163,6 +163,7 @@ loss.backward()
- [x] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <a href="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [x] accommodate variable-length audio, bring in eos token
- [x] make sure unique consecutive works with coarse transformer
- [x] pretty printing all discriminator losses to log

- [ ] complete full training code for soundstream, taking care of discriminator training
- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
@@ -175,7 +176,6 @@ loss.backward()
- [ ] test with speech synthesis for starters
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
- [ ] add option to use flash attention
- [ ] function for pretty printing all discriminator losses to log
- [ ] simplify training even more within AudioLM class

## Citations
+33 −0
Original line number Diff line number Diff line
@@ -41,6 +41,39 @@ def round_down_nearest_multiple(val, mult):
def grad_shrink(t, alpha = 0.1):
    """Return a tensor equal in value to ``t`` whose gradient is scaled by ``alpha``.

    The output is a blend of ``t`` and a detached copy of ``t`` whose weights
    sum to one, so the forward value is unchanged (up to floating point
    rounding) while only the ``alpha`` share carries gradient back to ``t``.
    """
    attached_part = t * alpha
    # the detached share contributes to the value but not to the gradient
    detached_part = t.detach() * (1 - alpha)
    return attached_part + detached_part

# sampling helpers

def log(t, eps = 1e-20):
    """Natural logarithm of ``t`` after adding ``eps``.

    The small offset keeps the result finite when ``t`` contains zeros.
    """
    shifted = t + eps
    return shifted.log()

def gumbel_noise(t):
    """Sample noise shaped like ``t`` by applying ``-log(-log(u))`` to uniform draws.

    ``u`` is drawn uniformly via ``uniform_(0, 1)`` into a tensor created with
    ``torch.zeros_like(t)``; both logarithms go through this file's ``log``
    helper, which adds a small epsilon before taking the log.
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    neg_log_uniform = -log(uniform)
    return -log(neg_log_uniform)

def gumbel_sample(t, temperature = 1., dim = -1):
    """Sample indices from logits ``t`` by adding Gumbel noise and taking the argmax.

    Args:
        t: tensor of logits.
        temperature: divisor applied to the logits before the noise is added.
            Values at or below ``1e-10`` (including ``0``) are clamped to
            ``1e-10`` so the division never produces inf/nan; at that scale
            the logits are expected to dominate the noise, making the result
            effectively a greedy argmax.
        dim: dimension along which the argmax is taken.

    Returns:
        Tensor of indices along ``dim``.
    """
    # guard against division by zero when the caller asks for temperature = 0
    temperature = max(temperature, 1e-10)
    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.5):
    """Keep only the largest logits along the last dimension; set the rest to ``-inf``.

    Args:
        logits: tensor of logits with any number of leading dimensions;
            filtering is applied along the last dimension.
        thres: fraction of entries to discard. The number kept is
            ``max(int((1 - thres) * num_logits), 1)``, so at least one
            entry always survives.

    Returns:
        Tensor with the same shape as ``logits`` in which the kept entries
        retain their original values and every other entry is ``-inf``.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim = -1)

    filtered = torch.full_like(logits, float('-inf'))
    # scatter along the same (last) dimension topk selected from, so inputs
    # that are not exactly 2-D are handled correctly
    filtered.scatter_(-1, ind, val)
    return filtered

def mask_out_after_eos_id(t, eos_id, mask_value = -1, include_eos = True):
    """Replace every position after the first ``eos_id`` along the last dim with ``mask_value``.

    Args:
        t: tensor of token ids.
        eos_id: id marking the end of a sequence.
        mask_value: value written into the masked positions.
        include_eos: when True the first eos token itself is left in place and
            only positions after it are masked; when False the eos position
            is masked as well.

    Returns:
        A new tensor produced by ``masked_fill``.
    """
    is_eos = t.eq(eos_id).float()

    if include_eos:
        # pad one on the left and crop one on the right: shifts the flags
        # one step forward so the eos position itself stays unmasked
        is_eos = F.pad(is_eos, (1, -1))

    # a running sum above zero means an eos flag has already been seen
    seen_eos = is_eos.cumsum(dim = -1)
    return t.masked_fill(seen_eos > 0, mask_value)

def all_rows_have_eos_id(t, eos_id):
    """Return a boolean scalar tensor: True when ``eos_id`` occurs in every row.

    A "row" is a slice along the last dimension of ``t``.
    """
    row_has_eos = t.eq(eos_id).any(dim = -1)
    return row_has_eos.all()

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
+12 −1
Original line number Diff line number Diff line
@@ -256,9 +256,20 @@ class SoundStreamTrainer(nn.Module):
            discr_optimizer.step()
            discr_optimizer.zero_grad()

        # build pretty printed losses

        losses_str = f"{steps}: soundstream loss: {logs['loss']}"

        for key, loss in logs.items():
            if not key.startswith('scale:'):
                continue
            _, scale_factor = key.split(':')

            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.2f}"

        # log

        self.print(f"{steps}: soundstream loss: {logs['loss']}")
        self.print(losses_str)

        # update exponential moving averaged generator

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.27',
  version = '0.0.28',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',