Commit b0384e83 authored by Phil Wang
Browse files

fix soundstream training

parent 8ebbb2e9
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -263,7 +263,7 @@ class SoundStreamTrainer(nn.Module):
        # update vae (generator)

        for _ in range(self.grad_accum_every):
-            wave = next(self.dl_iter)
+            wave, = next(self.dl_iter)
            wave = wave.to(device)

            loss, (recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True)
@@ -284,7 +284,7 @@ class SoundStreamTrainer(nn.Module):
        # update discriminator

        for _ in range(self.grad_accum_every):
-            wave = next(self.dl_iter)
+            wave, = next(self.dl_iter)
            wave = wave.to(device)

            discr_losses = self.soundstream(
@@ -337,7 +337,7 @@ class SoundStreamTrainer(nn.Module):
            for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.soundstream, str(steps))):
                model.eval()

-                wave = next(self.valid_dl_iter)
+                wave, = next(self.valid_dl_iter)
                wave = wave.to(device)

                recons = model(wave, return_recons_only = True)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.1.19',
+  version = '0.1.20',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',