Commit 5cf4e945 authored by Phil Wang's avatar Phil Wang
Browse files

ensure semantic transformer generations always include an eos for each sequence

parent 70f02c5b
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -424,6 +424,20 @@ class SemanticTransformer(nn.Module):
            last_logit_indices += 1

        output = mask_out_after_eos_id(output, self.pad_id, include_eos = include_eos_in_output)

        # ensure all sequences have eos

        has_eos_mask = (output == self.eos_id).any(dim = -1)

        if not has_eos_mask.all():
            append_eos_or_pad = torch.where(
                has_eos_mask,
                torch.full((batch, 1), self.pad_id, dtype = torch.long, device = device),
                torch.full((batch, 1), self.eos_id, dtype = torch.long, device = device),
            )

            output = torch.cat((output, append_eos_or_pad), dim = -1)

        return output

    def forward_with_cond_scale(
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.32',
  version = '0.0.33',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',