Commit daeedb27 authored by Phil Wang
Browse files

make sure to address the discriminator gradient issues uncovered by...

make sure to address the discriminator gradient issues uncovered by @apoorv2904 for multiscale discriminators too
parent ec642cd9
Loading
Loading
Loading
Loading
+9 −4
Original line number Diff line number Diff line
@@ -266,6 +266,10 @@ class SoundStreamTrainer(nn.Module):
        for ind, discr in enumerate(self.unwrapped_soundstream.discriminators):
            yield f'multiscale_discr_optimizer_{ind}', discr

    def multiscale_discriminator_optim_iter(self):
        """Yield ``(name, optimizer)`` pairs for the multiscale discriminators.

        The names come from ``multiscale_discriminator_iter()``; the object
        paired with each name there is ignored, and the attribute stored on
        ``self`` under that name is yielded in its place.
        """
        for optim_name, _discr in self.multiscale_discriminator_iter():
            optim = getattr(self, optim_name)
            yield optim_name, optim

    def print(self, msg):
        """Forward ``msg`` to the ``print`` method of ``self.accelerator``."""
        accelerator = self.accelerator
        accelerator.print(msg)

@@ -322,6 +326,9 @@ class SoundStreamTrainer(nn.Module):

        self.discr_optim.zero_grad()

        for name, multiscale_discr_optim in self.multiscale_discriminator_optim_iter():
            multiscale_discr_optim.zero_grad()

        for _ in range(self.grad_accum_every):
            wave, = next(self.dl_iter)
            wave = wave.to(device)
@@ -344,10 +351,8 @@ class SoundStreamTrainer(nn.Module):

        self.discr_optim.step()

        for ind in range(len(self.soundstream.discriminators)):
            discr_optimizer = getattr(self, f'multiscale_discr_optimizer_{ind}')
            discr_optimizer.step()
            discr_optimizer.zero_grad()
        for name, multiscale_discr_optim in self.multiscale_discriminator_optim_iter():
            multiscale_discr_optim.step()

        # build pretty printed losses

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.5.0',
  version = '0.5.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',