Commit 4a206d7b authored by Phil Wang's avatar Phil Wang
Browse files

take care of training all multiscale discriminators

parent 07a7eb00
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -76,6 +76,7 @@ loss.backward()
- [ ] test with speech synthesis for starters
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
- [ ] add option to use flash attention
- [ ] function for pretty printing all discriminator losses to log

## Citations

+30 −12
Original line number Diff line number Diff line
@@ -350,6 +350,9 @@ class SoundStream(nn.Module):
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feature_loss_weight = feature_loss_weight

    def non_discr_parameters(self):
        return [*self.encoder.parameters(), *self.decoder.parameters()]

    @property
    def seq_len_multiple_of(self):
        return functools.reduce(lambda x, y: x * y, self.strides)
@@ -359,8 +362,8 @@ class SoundStream(nn.Module):
        x,
        return_encoded = False,
        return_discr_loss = False,
        return_recons_only = False,
        return_stft_discr_loss = False
        return_discr_losses_separately = False,
        return_recons_only = False
    ):
        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')
@@ -381,21 +384,19 @@ class SoundStream(nn.Module):
        if return_recons_only:
            return recon_x

        # stft discr loss

        if return_stft_discr_loss:
            assert self.single_channel
            real, fake = orig_x, recon_x.detach()
            stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake))
            stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2
            return stft_discr_loss

        # multi-scale discriminator loss

        if return_discr_loss:
            real, fake = orig_x, recon_x.detach()

            stft_discr_loss = None
            discr_losses = []

            if self.single_channel:
                real, fake = orig_x, recon_x.detach()
                stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake))
                stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2

            for discr, scale in zip(self.discriminators, self.discr_multi_scales):
                scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake))

@@ -403,7 +404,24 @@ class SoundStream(nn.Module):
                one_discr_loss = hinge_discr_loss(fake_logits, real_logits)
                discr_losses.append(one_discr_loss)

            return torch.stack(discr_losses).mean()
            if not return_discr_losses_separately:
                all_discr_losses = torch.stack(discr_losses).mean()

                if exists(stft_discr_loss):
                    all_discr_losses = all_discr_losses + stft_discr_loss

                return all_discr_losses

            # return a list of discriminator losses with List[Tuple[str, Tensor]]

            discr_losses_pkg = []

            discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)])

            if exists(stft_discr_loss):
                discr_losses_pkg.append(('stft', stft_discr_loss))

            return discr_losses_pkg

        # recon loss

+5 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ class SoundDataset(Dataset):
        self,
        folder,
        exts = ['flac', 'wav'],
        max_length = None,
        seq_len_multiple_of = None
    ):
        super().__init__()
@@ -28,6 +29,7 @@ class SoundDataset(Dataset):
        assert len(files) > 0, 'no sound files found'

        self.files = files
        self.max_length = max_length
        self.seq_len_multiple_of = seq_len_multiple_of

    def __len__(self):
@@ -39,6 +41,9 @@ class SoundDataset(Dataset):

        data = rearrange(data, '1 ... -> ...')

        if exists(self.max_length):
            data = data[:self.max_length]

        if exists(self.seq_len_multiple_of):
            mult = self.seq_len_multiple_of
            data_len = len(data)
+34 −19
Original line number Diff line number Diff line
@@ -64,6 +64,7 @@ class SoundStreamTrainer(nn.Module):
        *,
        num_train_steps,
        batch_size,
        data_max_length = None,
        folder,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -93,14 +94,17 @@ class SoundStreamTrainer(nn.Module):
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        all_parameters = set(soundstream.parameters())
        discr_parameters = set(soundstream.stft_discriminator.parameters())
        soundstream_parameters = all_parameters - discr_parameters
        # optimizers

        self.soundstream_parameters = soundstream_parameters
        self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)

        self.optim = get_optimizer(soundstream_parameters, lr = lr, wd = wd)
        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
        for ind, discr in enumerate(soundstream.discriminators):
            one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd)
            setattr(self, f'multiscale_discr_optimizer_{ind}', one_multiscale_discr_optimizer)

        self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd)

        # max grad norm

        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm
@@ -109,6 +113,7 @@ class SoundStreamTrainer(nn.Module):

        self.ds = SoundDataset(
            folder,
            max_length = data_max_length,
            seq_len_multiple_of = soundstream.seq_len_multiple_of
        )

@@ -211,26 +216,36 @@ class SoundStreamTrainer(nn.Module):

        # update discriminator

        if exists(self.soundstream.stft_discriminator):
        for _ in range(self.grad_accum_every):
            wave = next(self.dl_iter)
            wave = wave.to(device)

                loss = self.soundstream(wave, return_discr_loss = True)

                self.accelerator.backward(loss / self.grad_accum_every)
            discr_losses = self.soundstream(
                wave,
                return_discr_loss = True,
                return_discr_losses_separately = True
            )

                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
            for name, discr_loss in discr_losses:
                self.accelerator.backward(discr_loss / self.grad_accum_every)
                accum_log(logs, {name: discr_loss.item() / self.grad_accum_every})

        if exists(self.discr_max_grad_norm):
            self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm)

        # gradient step for all discriminators

        self.discr_optim.step()
        self.discr_optim.zero_grad()

        for ind in range(len(self.soundstream.discriminators)):
            discr_optimizer = getattr(self, f'multiscale_discr_optimizer_{ind}')
            discr_optimizer.step()
            discr_optimizer.zero_grad()

        # log

            self.print(f"{steps}: soundstream loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
        self.print(f"{steps}: soundstream loss: {logs['loss']}")

        # update exponential moving averaged generator

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.17',
  version = '0.0.18',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',