Commit 2b6c5662 authored by Phil Wang's avatar Phil Wang
Browse files

follow researcher @eonglints advice and add unique consecutive for semantic...

follow researcher @eonglints advice and add unique consecutive for semantic tokens in semantic transformer, add eos to all three transformers in preparation for variable sequence lengths, make a note to refactor coarse transformer for unique consecutive semantic token ids
parent ccf9c1d8
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -60,12 +60,14 @@ loss.backward()
- [x] use fairseq vq-wav2vec for embeddings
- [x] add conditioning
- [x] add classifier free guidance
- [x] add unique consecutive for 
- [x] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <a href="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [x] accommodate variable lengthed audio, bring in eos token

- [ ] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <a href="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [ ] refactor coarse transformer embeddings so that unique_consecutive can be applied to semantic tokens and can be variable lengthed
- [ ] complete full training code for soundstream, taking care of discriminator training
- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
- [ ] complete sampling code for both Coarse and Fine Transformers, which will be tricky
- [ ] accommodate variable lengthed audio, bring in eos token
- [ ] full transformer training code for all three transformers
- [ ] make sure full inference with or without prompting works on the `AudioLM` class
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
+70 −14
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, repeat

@@ -27,6 +28,9 @@ def default(val, d):
def ceil_div(numer, denom):
    return (numer + denom - 1) // denom

def remainder_needed_until_multiple(n, mult):
    return (ceil_div(n, mult) * mult) - n

def round_down_nearest_multiple(val, mult):
    return (val // mult) * mult

@@ -69,6 +73,20 @@ def prob_mask_like(shape, prob, device):
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# removing unique consecutives in the semantic token ids
# important detail noted by @eonglints

def append_eos_id(ids, eos_id):
    b, device = ids.shape[0], ids.device
    eos_ids = torch.ones(1, device = device).long() * eos_id
    eos_ids = repeat(eos_ids, '1 -> b 1', b = b)
    ids = torch.cat((ids, eos_ids), dim = -1)
    return ids

def batch_unique_consecutive(t, pad_value = 0.):
    unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)]
    return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value)

# discriminators

class MultiScaleDiscriminator(nn.Module):
@@ -624,6 +642,8 @@ class SemanticTransformer(nn.Module):
        has_condition = False,
        cond_drop_prob = 0.5,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        unique_consecutive = True,
        pad_id = -1,
        **kwargs
    ):
        super().__init__()
@@ -631,13 +651,17 @@ class SemanticTransformer(nn.Module):
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.unique_consecutive = unique_consecutive

        self.start_token = nn.Parameter(torch.randn(dim))

        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)
        self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim)
        self.eos_id = num_semantic_tokens
        self.pad_id = pad_id

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens)
        self.to_logits = nn.Linear(dim, num_semantic_tokens + 1)

    def forward(
        self,
@@ -659,6 +683,11 @@ class SemanticTransformer(nn.Module):

        b = ids.shape[0]

        ids = append_eos_id(ids, self.eos_id)

        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids)

        has_text = exists(text) or exists(text_embed)
        assert not (self.has_condition ^ has_text)

@@ -690,7 +719,8 @@ class SemanticTransformer(nn.Module):

        loss = F.cross_entropy(
            rearrange(logits, 'b n c -> b c n'),
            labels
            labels,
            ignore_index = self.pad_id
        )

        return loss
@@ -716,8 +746,12 @@ class CoarseTransformer(nn.Module):

        self.start_token = nn.Parameter(torch.randn(dim))

        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
        self.semantic_eos_id = num_semantic_tokens
        self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim)

        self.coarse_eos_id = codebook_size
        codebook_size_with_eos = codebook_size + 1
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)
@@ -725,8 +759,8 @@ class CoarseTransformer(nn.Module):
        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers

        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens)
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim))
        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens + 1)
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))

    def forward(
        self,
@@ -821,8 +855,12 @@ class FineTransformer(nn.Module):

        self.start_token = nn.Parameter(torch.randn(dim))

        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size, dim)
        codebook_size_with_eos = codebook_size + 1

        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size_with_eos, dim)

        self.eos_id = codebook_size

        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)

@@ -830,8 +868,8 @@ class FineTransformer(nn.Module):
        self.num_coarse_quantizers = num_coarse_quantizers
        self.num_fine_quantizers = num_fine_quantizers

        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim))
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size, dim))
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim))
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim))

    def forward(
        self,
@@ -883,16 +921,25 @@ class FineTransformer(nn.Module):

        # get coarse logits

        pred_coarse_seq_len = pred_coarse_tokens.shape[1]

        padding = remainder_needed_until_multiple(pred_coarse_seq_len, self.num_coarse_quantizers)

        if padding != 0:
            pred_coarse_tokens = F.pad(pred_coarse_tokens, (0, 0, 0, padding), value = 0.)

        pred_coarse_tokens = rearrange(pred_coarse_tokens, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers)

        coarse_logits = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens)

        coarse_logits = rearrange(coarse_logits, 'b n q c -> b (n q) c')

        coarse_logits = coarse_logits[:, :pred_coarse_seq_len]

        # get fine logits

        n = pred_fine_tokens.shape[1]
        nq = round_down_nearest_multiple(n, self.num_fine_quantizers)
        pred_fine_seq_len = pred_fine_tokens.shape[1]
        nq = round_down_nearest_multiple(pred_fine_seq_len, self.num_fine_quantizers)

        pred_fine_tokens_groupable, pred_fine_tokens_remainder = pred_fine_tokens[:, :nq], pred_fine_tokens[:, nq:]

@@ -952,6 +999,9 @@ class FineTransformerWrapper(nn.Module):
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')

        coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.eos_id)
        fine_token_ids = append_eos_id(fine_token_ids, self.transformer.eos_id)

        if return_loss:
            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
            fine_token_ids = fine_token_ids[:, :-1]
@@ -986,7 +1036,8 @@ class CoarseTransformerWrapper(nn.Module):
        transformer: FineTransformer,
        soundstream: Optional[SoundStream]  = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        num_coarse_quantize = 3
        num_coarse_quantize = 3,
        unique_consecutive = False
    ):
        super().__init__()
        self.soundstream = soundstream
@@ -994,6 +1045,8 @@ class CoarseTransformerWrapper(nn.Module):

        self.transformer = transformer

        assert not unique_consecutive, 'not implemented yet'

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize

@@ -1025,6 +1078,9 @@ class CoarseTransformerWrapper(nn.Module):
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')

        coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.coarse_eos_id)
        semantic_token_ids = append_eos_id(semantic_token_ids, self.transformer.semantic_eos_id)

        if return_loss:
            semantic_labels, coarse_labels = semantic_token_ids, coarse_token_ids.clone()
            coarse_token_ids = coarse_token_ids[:, :-1]
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.9',
  version = '0.0.10',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',