Commit 30f04de7 authored by Phil Wang
Browse files

add classifier free guidance training logic, cite

parent c17ee7d3
Loading
Loading
Loading
Loading
+19 −1
Original line number Diff line number Diff line
@@ -59,6 +59,7 @@ loss.backward()
- [x] complete CoarseTransformer
- [x] use fairseq vq-wav2vec for embeddings
- [x] add conditioning
- [x] add classifier free guidance

- [ ] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <a href="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [ ] complete full training code for soundstream, taking care of discriminator training
@@ -71,7 +72,7 @@ loss.backward()
- [ ] DRY a little at the end
- [ ] figure out how to suppress logging in fairseq
- [ ] test with speech synthesis for starters
- [ ] add classifier free guidance
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package

## Citations

@@ -111,3 +112,20 @@ loss.backward()
    volume  = {abs/1911.02150}
}
```

```bibtex
@article{Ho2022ClassifierFreeDG,
    title   = {Classifier-Free Diffusion Guidance},
    author  = {Jonathan Ho},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2207.12598}
}
```

```bibtex
@misc{crowson2022,
    author  = {Katherine Crowson},
    url     = {https://twitter.com/rivershavewings}
}
```
+51 −10
Original line number Diff line number Diff line
@@ -56,6 +56,19 @@ def gradient_penalty(images, output, weight = 10):
    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# classifier free guidance functions

def uniform(shape, device):
    """Return a float32 tensor of `shape` on `device`, filled with samples from U[0, 1)."""
    samples = torch.zeros(shape, device = device, dtype = torch.float32)
    # fill in place so the result keeps the requested shape and device
    samples.uniform_(0, 1)
    return samples

def prob_mask_like(shape, prob, device):
    """Build a boolean mask of `shape` on `device` whose entries are True with probability `prob`.

    `prob == 1` yields an all-True mask and `prob == 0` an all-False mask
    without drawing any random numbers; any other value compares fresh
    U[0, 1) samples against `prob`.
    """
    # degenerate probabilities: return a constant mask, no sampling needed
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)

    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    noise = torch.zeros(shape, device = device, dtype = torch.float32)
    noise.uniform_(0, 1)
    return noise < prob

# discriminators

class MultiScaleDiscriminator(nn.Module):
@@ -609,12 +622,14 @@ class SemanticTransformer(nn.Module):
        dim,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        **kwargs
    ):
        super().__init__()
        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.start_token = nn.Parameter(torch.randn(dim))

@@ -631,10 +646,19 @@ class SemanticTransformer(nn.Module):
        ids = None,
        return_loss = False,
        text = None,
        text_embed = None
        text_embed = None,
        cond_drop_prob = None
    ):
        device = next(self.parameters()).device

        assert exists(raw_wave) ^ exists(ids)

        if not exists(ids):
            assert exists(self.wav2vec)
            ids = self.wav2vec(raw_wave, flatten = False)

        b = ids.shape[0]

        has_text = exists(text) or exists(text_embed)
        assert not (self.has_condition ^ has_text)

@@ -643,11 +667,11 @@ class SemanticTransformer(nn.Module):
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        assert exists(raw_wave) ^ exists(ids)
        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if not exists(ids):
            assert exists(self.wav2vec)
            ids = self.wav2vec(raw_wave, flatten = False)
        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        if return_loss:
            labels, ids = ids.clone(), ids[:, :-1]
@@ -681,12 +705,14 @@ class CoarseTransformer(nn.Module):
        dim,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        **kwargs
    ):
        super().__init__()
        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.start_token = nn.Parameter(torch.randn(dim))

@@ -708,7 +734,8 @@ class CoarseTransformer(nn.Module):
        semantic_token_ids,
        coarse_token_ids,
        text = None,
        text_embed = None
        text_embed = None,
        cond_drop_prob = None
    ):
        b, device = semantic_token_ids.shape[0], semantic_token_ids.device

@@ -720,6 +747,12 @@ class CoarseTransformer(nn.Module):
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))

        offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
@@ -778,11 +811,13 @@ class FineTransformer(nn.Module):
        dim,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
        **kwargs
    ):
        super().__init__()
        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.start_token = nn.Parameter(torch.randn(dim))

@@ -803,10 +838,10 @@ class FineTransformer(nn.Module):
        coarse_token_ids,
        fine_token_ids,
        text = None,
        text_embed = None
        text_embed = None,
        cond_drop_prob = None
    ):
        device = coarse_token_ids.device

        b, device = coarse_token_ids.shape[0], coarse_token_ids.device
        has_text = exists(text) or exists(text_embed)
        assert not (self.has_condition ^ has_text)

@@ -815,6 +850,12 @@ class FineTransformer(nn.Module):
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask

        coarse_token_ids, fine_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, fine_token_ids))

        b, n = coarse_token_ids.shape
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.8',
  version = '0.0.9',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',