Commit 0806f62b authored by Phil Wang
Browse files

only encode raw text once when generating

parent d9152df6
Loading
Loading
Loading
Loading
+17 −3
Original line number Diff line number Diff line
@@ -353,6 +353,8 @@ class SemanticTransformer(nn.Module):
        self,
        *,
        max_length,
        text = None,
        text_embeds = None,
        prime_wave = None,
        prime_ids = None,
        batch_size = 1,
@@ -377,6 +379,15 @@ class SemanticTransformer(nn.Module):
        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # derive text embeddings if needed

        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)

        # start length and get running id output

        start_length = ids.shape[-1]
@@ -388,6 +399,7 @@ class SemanticTransformer(nn.Module):

            logits = self.forward_with_cond_scale(
                ids = output,
                text_embeds = text_embeds,
                **kwargs
            )

@@ -401,7 +413,7 @@ class SemanticTransformer(nn.Module):
            if all_rows_have_eos_id(output, self.eos_id):
                break

        output = mask_out_after_eos_id(output, self.eos_id)
        output = mask_out_after_eos_id(output, self.pad_id)
        return output

    def forward_with_cond_scale(
@@ -559,10 +571,12 @@ class CoarseTransformer(nn.Module):
        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)

        text_mask = None
        if not exists(text_embeds) and exists(text):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)

        text_mask = None
        if exists(text_embeds):
            text_mask = torch.any(text_embeds != 0, dim = -1)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.29',
  version = '0.0.30',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',