Commit 5f964966 authored by Phil Wang's avatar Phil Wang
Browse files

properly accelerator prepare all multiscale discriminators

parent 4a206d7b
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -98,9 +98,9 @@ class SoundStreamTrainer(nn.Module):

        self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)

        for ind, discr in enumerate(soundstream.discriminators):
        for discr_optimizer_key, discr in self.multiscale_discriminator_iter():
            one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd)
            setattr(self, f'multiscale_discr_optimizer_{ind}', one_multiscale_discr_optimizer)
            setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer)

        self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd)

@@ -150,6 +150,15 @@ class SoundStreamTrainer(nn.Module):
            self.valid_dl
        )

        # prepare the multiscale discriminators with accelerator

        for name, _ in self.multiscale_discriminator_iter():
            optimizer = getattr(self, name)
            optimizer = self.accelerator.prepare(optimizer)
            setattr(self, name, optimizer)

        # dataloader iterators

        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

@@ -165,6 +174,10 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def multiscale_discriminator_iter(self):
        for ind, discr in enumerate(self.soundstream.discriminators):
            yield f'multiscale_discr_optimizer_{ind}', discr

    def print(self, msg):
        self.accelerator.print(msg)

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.18',
  version = '0.0.19',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',