Commit 2c2d43d8 authored by Phil Wang's avatar Phil Wang
Browse files

just give coarse sequence in coarse transformer its own start token

parent 5cf4e945
Loading
Loading
Loading
Loading
+12 −14
Original line number Diff line number Diff line
@@ -544,7 +544,8 @@ class CoarseTransformer(nn.Module):
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.start_token = nn.Parameter(torch.randn(dim))
        self.semantic_start_token = nn.Parameter(torch.randn(dim))
        self.coarse_start_token = nn.Parameter(torch.randn(dim))

        self.semantic_eos_id = num_semantic_tokens
        self.semantic_embedding = nn.Embedding(num_semantic_tokens + 1, dim)
@@ -621,24 +622,20 @@ class CoarseTransformer(nn.Module):

        semantic_seq_len = semantic_tokens.shape[1]

        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b)
        semantic_start_tokens = repeat(self.semantic_start_token, 'd -> b 1 d', b = b)
        coarse_start_tokens = repeat(self.coarse_start_token, 'd -> b 1 d', b = b)

        tokens = torch.cat((start_tokens, semantic_tokens, coarse_tokens), dim = 1)
        tokens = torch.cat((
            semantic_start_tokens,
            semantic_tokens,
            coarse_start_tokens,
            coarse_tokens
        ), dim = 1)

        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)

        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):]

        # get the eos token from predicted semantic tokens, and use that to predict the first coarse token

        semantic_eos = semantic_token_ids == self.semantic_eos_id
        pred_semantic_eos_tokens = tokens[:, 1:(semantic_seq_len + 1)][semantic_eos]

        pred_coarse_tokens = torch.cat((
            rearrange(pred_semantic_eos_tokens, 'b d -> b 1 d'),
            pred_coarse_tokens),
        dim = 1)

        # semantic logits

        semantic_logits = self.to_semantic_logits(pred_semantic_tokens)
@@ -880,7 +877,8 @@ class CoarseTransformerWrapper(nn.Module):
        if self.unique_consecutive:
            self_attn_mask = semantic_token_ids != self.pad_id
            semantic_token_ids = semantic_token_ids.masked_fill(~self_attn_mask, 0)
            self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_ids.shape[-1]), value = True)
            coarse_token_len = coarse_token_ids.shape[-1]
            self_attn_mask = F.pad(self_attn_mask, (1, coarse_token_len + 1), value = True)

        semantic_logits, coarse_logits = self.transformer(
            semantic_token_ids = semantic_token_ids,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.33',
  version = '0.0.34',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',