Commit 09c79a04 authored by Phil Wang's avatar Phil Wang
Browse files

complete first pass at unique consecutive issue with semantic token ids, by...

complete first pass at unique consecutive issue with semantic token ids, by using key padding masking in coarse transformer in causal self attn
parent e5408fbd
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -63,8 +63,8 @@ loss.backward()
- [x] add unique consecutive for 
- [x] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <a href="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [x] accommodate variable lengthed audio, bring in eos token
- [x] make sure unique consecutive works with coarse transformer

- [ ] refactor coarse transformer embeddings so that unique_consecutive can be applied to semantic tokens and can be variable lengthed
- [ ] complete full training code for soundstream, taking care of discriminator training
- [ ] figure out how to do the normalization across each dimension mentioned in the paper, but ignore it for v1 of the framework
- [ ] complete sampling code for both Coarse and Fine Transformers, which will be tricky
+25 −8
Original line number Diff line number Diff line
@@ -611,6 +611,7 @@ class Transformer(nn.Module):
    def forward(
        self,
        x,
        self_attn_mask = None,
        context = None,
        context_mask = None
    ):
@@ -619,7 +620,7 @@ class Transformer(nn.Module):
        rel_pos_bias = self.rel_pos_bias(n, n, device = device)

        for attn, cross_attn, ff in self.layers:
            x = attn(x, attn_bias = rel_pos_bias) + x
            x = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask) + x

            if exists(cross_attn):
                assert exists(context)
@@ -686,7 +687,7 @@ class SemanticTransformer(nn.Module):
        ids = append_eos_id(ids, self.eos_id)

        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids)
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        has_text = exists(text) or exists(text_embed)
        assert not (self.has_condition ^ has_text)
@@ -767,6 +768,7 @@ class CoarseTransformer(nn.Module):
        *,
        semantic_token_ids,
        coarse_token_ids,
        self_attn_mask = None,
        text = None,
        text_embed = None,
        cond_drop_prob = None
@@ -803,7 +805,7 @@ class CoarseTransformer(nn.Module):

        tokens = torch.cat((start_tokens, semantic_tokens, coarse_tokens), dim = 1)

        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)
        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)

        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, semantic_seq_len:]

@@ -1039,15 +1041,16 @@ class CoarseTransformerWrapper(nn.Module):
        soundstream: Optional[SoundStream]  = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        num_coarse_quantize = 3,
        unique_consecutive = False
        pad_id = -1,
        unique_consecutive = True
    ):
        super().__init__()
        self.soundstream = soundstream
        self.wav2vec = wav2vec

        self.transformer = transformer

        assert not unique_consecutive, 'not implemented yet'
        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize
@@ -1083,13 +1086,23 @@ class CoarseTransformerWrapper(nn.Module):
        coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.coarse_eos_id)
        semantic_token_ids = append_eos_id(semantic_token_ids, self.transformer.semantic_eos_id)

        if self.unique_consecutive:
            semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value = self.pad_id)

        if return_loss:
            semantic_labels, coarse_labels = semantic_token_ids, coarse_token_ids.clone()
            coarse_token_ids = coarse_token_ids[:, :-1]

        self_attn_mask = None
        if self.unique_consecutive:
            self_attn_mask = semantic_token_ids != -1
            semantic_token_ids = semantic_token_ids.masked_fill(~self_attn_mask, 0)
            self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_ids.shape[-1]), value = True)

        semantic_logits, coarse_logits = self.transformer(
            semantic_token_ids = semantic_token_ids,
            coarse_token_ids = coarse_token_ids,
            self_attn_mask = self_attn_mask,
            **kwargs
        )

@@ -1098,11 +1111,15 @@ class CoarseTransformerWrapper(nn.Module):

        coarse_logits, semantic_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, semantic_logits))

        if self.unique_consecutive:
            num_coarse_logits, num_semantic_logits = coarse_logits.shape[0] * coarse_logits.shape[-1], self_attn_mask.sum()
        else:
            num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1]

        semantic_loss = F.cross_entropy(
            semantic_logits,
            semantic_labels
            semantic_labels,
            ignore_index = self.pad_id
        )

        coarse_loss = F.cross_entropy(
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.11',
  version = '0.0.12',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',