Commit d0151be9 authored by Phil Wang
Browse files

some tweaks to soundstream, include commitment loss

parent 894d8925
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -30,6 +30,7 @@ trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/librispeech',
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    num_train_steps = 10000
).cuda()
+14 −2
Original line number Diff line number Diff line
@@ -247,12 +247,13 @@ class SoundStream(nn.Module):
        codebook_dim = 512,
        codebook_size = 1024,
        rq_num_quantizers = 8,
        rq_commitment_weight = 1.,
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),
        recon_loss_weight = 1.,
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_cutoff_index = 1,
        target_sample_hz = 24000
    ):
        super().__init__()
@@ -280,6 +281,7 @@ class SoundStream(nn.Module):
            dim = codebook_dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            commitment_weight = rq_commitment_weight,
            kmeans_init = True,
            threshold_ema_dead_code = 2,
            quantize_dropout = True,
@@ -334,6 +336,7 @@ class SoundStream(nn.Module):
        return_encoded = False,
        return_discr_loss = False,
        return_discr_losses_separately = False,
        return_loss_breakdown = False,
        return_recons_only = False,
        input_sample_hz = None,
        apply_grad_penalty = False
@@ -459,4 +462,13 @@ class SoundStream(nn.Module):

        adversarial_loss = torch.stack(adversarial_losses).mean()

        return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight
        # sum commitment loss

        all_commitment_loss = commit_loss.sum()

        total_loss = recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss

        if return_loss_breakdown:
            return total_loss, (recon_loss, adversarial_loss, feature_loss, all_commitment_loss)

        return total_loss
+8 −5
Original line number Diff line number Diff line
@@ -77,7 +77,7 @@ class SoundStreamTrainer(nn.Module):
        data_max_length = None,
        folder,
        lr = 3e-4,
        grad_accum_every = 1,
        grad_accum_every = 4,
        wd = 0.,
        max_grad_norm = 0.5,
        discr_max_grad_norm = None,
@@ -226,11 +226,14 @@ class SoundStreamTrainer(nn.Module):
            wave = next(self.dl_iter)
            wave = wave.to(device)

            loss = self.soundstream(wave)
            loss, (recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
            accum_log(logs, dict(
                loss = loss.item() / self.grad_accum_every,
                recon_loss = recon_loss / self.grad_accum_every
            ))

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.soundstream.parameters(), self.max_grad_norm)
@@ -270,14 +273,14 @@ class SoundStreamTrainer(nn.Module):

        # build pretty printed losses

        losses_str = f"{steps}: soundstream loss: {logs['loss']}"
        losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}"

        for key, loss in logs.items():
            if not key.startswith('scale:'):
                continue
            _, scale_factor = key.split(':')

            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.2f}"
            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}"

        # log

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.8',
  version = '0.1.9',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',