Commit 0ac5e4f5 authored by Phil Wang's avatar Phil Wang
Browse files

rough sketch of all three transformers finished

parent 80a3fad4
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -52,10 +52,11 @@ loss.backward()

## Todo

- [x] complete CoarseTransformer

- [ ] complete full training code for soundstream, taking care of discriminator training
- [ ] use huggingface wav2vec for embeddings, use VQ library for learning the kmeans through reconstruction task
- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
- [ ] complete CoarseTransformer
- [ ] complete sampling code for both Coarse and Fine Transformers, which will be tricky
- [ ] accommodate variable-length audio, bring in EOS token
- [ ] full transformer training code for all three transformers
+1 −1
Original line number Diff line number Diff line
@@ -2,4 +2,4 @@ from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.audiolm_pytorch import SoundStream

from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper
+119 −3
Original line number Diff line number Diff line
@@ -589,19 +589,72 @@ class CoarseTransformer(nn.Module):
        *,
        num_semantic_tokens,
        codebook_size,
        num_coarse_tokens,
        num_coarse_quantizers,
        dim,
        **kwargs
    ):
        super().__init__()
        self.start_token = nn.Parameter(torch.randn(dim))

        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)

        self.transformer = Transformer(dim = dim, **kwargs)

        self.num_coarse_quantizers = num_coarse_quantizers

        self.to_semantic_logits = nn.Linear(dim, num_semantic_tokens)
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim))

    def forward(
        self,
        semantic_token_ids,
        coarse_token_ids,
    ):
        """Compute next-token logits for the semantic and coarse token streams.

        The input sequence fed to the transformer is
        ``[start token, semantic tokens..., coarse tokens...]``. The transformer
        output is then split at ``semantic_seq_len``: the first part is projected
        to semantic logits, the rest to coarse logits using one projection matrix
        per coarse quantizer.

        Args:
            semantic_token_ids: integer ids with a leading batch dimension; any
                trailing dimensions are flattened into one sequence axis.
            coarse_token_ids: integer ids with a leading batch dimension; any
                trailing dimensions are flattened into one sequence axis.
                NOTE(review): ids are embedded as given, without a
                per-quantizer offset into ``coarse_embedding`` -- confirm the
                caller is expected to supply offset ids.

        Returns:
            Tuple ``(semantic_logits, coarse_logits)``. Assuming the transformer
            preserves sequence length (TODO confirm), ``semantic_logits`` has one
            position per semantic token and ``coarse_logits`` has one position
            more than the number of coarse tokens passed in, because the output
            at the last semantic position predicts the first coarse token.
        """
        b = semantic_token_ids.shape[0]

        # flatten everything after the batch dimension into a single sequence axis
        coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))

        semantic_tokens = self.semantic_embedding(semantic_token_ids)
        coarse_tokens = self.coarse_embedding(coarse_token_ids)

        semantic_seq_len = semantic_tokens.shape[1]

        # one learned start token per batch element, prepended to the sequence
        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b)

        tokens = torch.cat((start_tokens, semantic_tokens, coarse_tokens), dim = 1)

        tokens = self.transformer(tokens)

        # because of the prepended start token, output position i lines up with
        # input token i as its prediction target
        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, semantic_seq_len:]

        # semantic logits

        semantic_logits = self.to_semantic_logits(pred_semantic_tokens)

        # get coarse logits

        n = pred_coarse_tokens.shape[1]
        nq = round_down_nearest_multiple(n, self.num_coarse_quantizers)

        # split into the part that divides evenly into groups of
        # `num_coarse_quantizers` positions and the leftover tail
        pred_coarse_tokens_groupable, pred_coarse_tokens_remainder = pred_coarse_tokens[:, :nq], pred_coarse_tokens[:, nq:]

        pred_coarse_tokens_groupable = rearrange(pred_coarse_tokens_groupable, 'b (n q) d -> b n q d', q = self.num_coarse_quantizers)

        # each quantizer position q is projected with its own (codebook_size, dim) weight
        coarse_logits_groupable = einsum('q c d, b n q d -> b n q c', self.coarse_logit_weights, pred_coarse_tokens_groupable)

        coarse_logits_groupable = rearrange(coarse_logits_groupable, 'b n q c -> b (n q) c')

        remainder_num_quantizers = pred_coarse_tokens_remainder.shape[1]

        if remainder_num_quantizers > 0:
            # the tail is a partial group, so it uses the weights of the first
            # `remainder_num_quantizers` quantizers
            coarse_logits_remainder = einsum('q c d, b q d -> b q c', self.coarse_logit_weights[:remainder_num_quantizers], pred_coarse_tokens_remainder)

            coarse_logits = torch.cat((coarse_logits_groupable, coarse_logits_remainder), dim = 1)
        else:
            coarse_logits = coarse_logits_groupable

        return semantic_logits, coarse_logits

class FineTransformer(nn.Module):
    def __init__(
@@ -622,6 +675,7 @@ class FineTransformer(nn.Module):

        self.transformer = Transformer(dim = dim, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
        self.num_fine_quantizers = num_fine_quantizers

@@ -672,7 +726,7 @@ class FineTransformer(nn.Module):
        remainder_num_quantizers = pred_fine_tokens_remainder.shape[1]

        if remainder_num_quantizers > 0:
            fine_logits_remainder = einsum('q c d, b n q d -> b n q c', self.fine_logit_weights[:remainder_num_quantizers], pred_fine_tokens_remainder)
            fine_logits_remainder = einsum('q c d, b q d -> b q c', self.fine_logit_weights[:remainder_num_quantizers], pred_fine_tokens_remainder)

            fine_logits = torch.cat((fine_logits_groupable, fine_logits_remainder), dim = 1)
        else:
@@ -744,6 +798,68 @@ class FineTransformerWrapper(nn.Module):

        return (coarse_loss + fine_loss) * 0.5

class CoarseTransformerWrapper(nn.Module):
    """Training wrapper around a coarse transformer.

    Accepts either pre-computed coarse token ids or a raw waveform. A raw
    waveform is first encoded to token ids by the optional SoundStream. The
    wrapper can return the transformer's logits or the averaged
    semantic + coarse cross-entropy loss.
    """

    def __init__(
        self,
        *,
        soundstream: Optional[SoundStream] = None,
        transformer: CoarseTransformer,
        num_coarse_quantize = 3
    ):
        """
        Args:
            soundstream: optional codec used to turn raw audio into token ids;
                only required when ``forward`` is called with ``raw_wave``.
            transformer: module called with ``semantic_token_ids`` and
                ``coarse_token_ids`` keyword arguments, returning
                ``(semantic_logits, coarse_logits)``.
            num_coarse_quantize: how many leading entries of the last axis of
                the SoundStream indices are treated as coarse tokens. Must be
                positive.
        """
        super().__init__()
        self.soundstream = soundstream
        self.transformer = transformer

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize

    def forward(
        self,
        *,
        semantic_token_ids,
        raw_wave = None,
        coarse_token_ids = None,
        return_loss = False
    ):
        """Run the coarse transformer, optionally returning the training loss.

        Exactly one of ``raw_wave`` and ``coarse_token_ids`` must be given.

        Args:
            semantic_token_ids: semantic token ids with a leading batch
                dimension; trailing dimensions are flattened.
            raw_wave: raw audio passed to the SoundStream to obtain coarse ids.
            coarse_token_ids: pre-computed coarse token ids with a leading
                batch dimension; trailing dimensions are flattened.
            return_loss: when true, return the mean of the semantic and coarse
                cross-entropy losses instead of the logits.

        Returns:
            ``(semantic_logits, coarse_logits)`` when ``return_loss`` is false,
            otherwise a scalar loss tensor.
        """
        assert exists(raw_wave) ^ exists(coarse_token_ids), 'either raw waveform (raw_wave) is given, or coarse token ids (coarse_token_ids), but not both'

        if exists(raw_wave):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'

            # the codec is only used as a tokenizer here, so no gradients are tracked
            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                # presumably the last axis indexes the residual quantizers -- TODO confirm;
                # only the leading (coarse) ones are needed by this wrapper
                coarse_token_ids = indices[..., :self.num_coarse_quantize]

        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')

        if return_loss:
            semantic_labels, coarse_labels = semantic_token_ids, coarse_token_ids.clone()
            # drop the last coarse token from the input: nothing follows it to predict
            coarse_token_ids = coarse_token_ids[:, :-1]

        semantic_logits, coarse_logits = self.transformer(
            semantic_token_ids = semantic_token_ids,
            coarse_token_ids = coarse_token_ids
        )

        if not return_loss:
            return semantic_logits, coarse_logits

        # F.cross_entropy expects the class dimension second: (batch, classes, seq)
        coarse_logits, semantic_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, semantic_logits))

        semantic_loss = F.cross_entropy(
            semantic_logits,
            semantic_labels
        )

        coarse_loss = F.cross_entropy(
            coarse_logits,
            coarse_labels
        )

        # equal weighting of the two objectives
        return (semantic_loss + coarse_loss) * 0.5

# audio LM

class AudioLM(nn.Module):