Commit bc626a64 authored by Phil Wang
Browse files

Prepare for variable-length coarse tokens in the fine transformer as well (0.0.36)

parent 2c2d43d8
Loading
Loading
Loading
Loading
+23 −7
Original line number Diff line number Diff line
@@ -683,7 +683,8 @@ class FineTransformer(nn.Module):
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.start_token = nn.Parameter(torch.randn(dim))
        self.coarse_start_token = nn.Parameter(torch.randn(dim))
        self.fine_start_token = nn.Parameter(torch.randn(dim))

        codebook_size_with_eos = codebook_size + 1

@@ -725,7 +726,8 @@ class FineTransformer(nn.Module):
        fine_token_ids,
        text = None,
        text_embeds = None,
        cond_drop_prob = None
        cond_drop_prob = None,
        self_attn_mask = None
    ):
        b, device = coarse_token_ids.shape[0], coarse_token_ids.device
        has_text = exists(text) or exists(text_embeds)
@@ -760,13 +762,19 @@ class FineTransformer(nn.Module):
        coarse_tokens = self.coarse_embedding(coarse_token_ids)
        fine_tokens = self.fine_embedding(fine_token_ids)

        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b)
        coarse_start_tokens = repeat(self.coarse_start_token, 'd -> b 1 d', b = b)
        fine_start_tokens = repeat(self.fine_start_token, 'd -> b 1 d', b = b)

        tokens = torch.cat((start_tokens, coarse_tokens, fine_tokens), dim = 1)
        tokens = torch.cat((
            coarse_start_tokens,
            coarse_tokens,
            fine_start_tokens,
            fine_tokens
        ), dim = 1)

        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)
        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask =self_attn_mask, context_mask = text_mask)

        pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:]
        pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, (n + 1):]

        # get coarse logits

@@ -916,7 +924,8 @@ class FineTransformerWrapper(nn.Module):
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        num_coarse_quantize = 3
        num_coarse_quantize = 3,
        pad_id = -1
    ):
        super().__init__()
        self.soundstream = soundstream
@@ -924,6 +933,7 @@ class FineTransformerWrapper(nn.Module):

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize
        self.pad_id = pad_id

    def forward(
        self,
@@ -955,9 +965,15 @@ class FineTransformerWrapper(nn.Module):
            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
            fine_token_ids = fine_token_ids[:, :-1]

        self_attn_mask = coarse_token_ids != self.pad_id

        fine_token_seq_len = fine_token_ids.shape[-1]
        self_attn_mask = F.pad(self_attn_mask, (1, fine_token_seq_len + 1), value = True)

        coarse_logits, fine_logits = self.transformer(
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids,
            self_attn_mask = self_attn_mask,
            **kwargs
        )

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.34',
  version = '0.0.36',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',