Commit d244a6c9 authored by Phil Wang's avatar Phil Wang
Browse files

take care of soundstream accepting audio without batch dimension

parent 255b578f
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -62,6 +62,11 @@ trainer = SoundStreamTrainer(
).cuda()

trainer.train()

# after a lot of training, you can test the autoencoding as so

audio = torch.randn(10080).cuda()
recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel
```

Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained
+4 −2
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ from torch.linalg import vector_norm

import torchaudio.transforms as T

from einops import rearrange, reduce
from einops import rearrange, reduce, pack, unpack

from vector_quantize_pytorch import ResidualVQ

@@ -518,7 +518,6 @@ class SoundStream(nn.Module):
        assert path.exists()
        obj = torch.load(str(path))
        self.load_state_dict(obj['model'])
        exit()

    def non_discr_parameters(self):
        return [
@@ -543,6 +542,8 @@ class SoundStream(nn.Module):
        input_sample_hz = None,
        apply_grad_penalty = False
    ):
        x, ps = pack([x], '* n')

        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)

@@ -573,6 +574,7 @@ class SoundStream(nn.Module):
        recon_x = self.decoder(x)

        if return_recons_only:
            recon_x, = unpack(recon_x, ps, '* c n')
            return recon_x

        # multi-scale discriminator loss
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.15.1',
  version = '0.15.2',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',