Commit 99bf7cb6 authored by Phil Wang's avatar Phil Wang
Browse files

prepare soundstream for latent diffusion in natural speech 2 repository

parent f837c787
Loading
Loading
Loading
Loading
+10 −4
Original line number Diff line number Diff line
@@ -422,7 +422,7 @@ class SoundStream(nn.Module):
        multi_spectral_n_ffts = 512,
        multi_spectral_n_mels = 64,
        recon_loss_weight = 1.,
        multi_spectral_recon_loss_weight = 1.,
        multi_spectral_recon_loss_weight = 1e-5,
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 1,
@@ -568,6 +568,12 @@ class SoundStream(nn.Module):
        codes = self.rq.get_codes_from_indices(quantized_indices)
        x = reduce(codes, 'q ... -> ...', 'sum')

        return self.decode(x)

    def decode(self, x, quantize = False):
        if quantize:
            x, *_ = self.rq(x)

        x = self.decoder_attn(x)
        x = rearrange(x, 'b n c -> b c n')
        return self.decoder(x)
@@ -664,14 +670,14 @@ class SoundStream(nn.Module):

        x, indices, commit_loss = self.rq(x)

        if return_encoded:
            return x, indices, commit_loss

        if exists(self.decoder_attn):
            x = self.decoder_attn(x)

        x = rearrange(x, 'b n c -> b c n')

        if return_encoded:
            return x, indices, commit_loss

        recon_x = self.decoder(x)

        if return_recons_only:
+1 −1
Original line number Diff line number Diff line
__version__ = '0.26.8'
__version__ = '0.27.0'