Loading audiolm_pytorch/soundstream.py +10 −4 Original line number Diff line number Diff line Loading @@ -422,7 +422,7 @@ class SoundStream(nn.Module): multi_spectral_n_ffts = 512, multi_spectral_n_mels = 64, recon_loss_weight = 1., multi_spectral_recon_loss_weight = 1., multi_spectral_recon_loss_weight = 1e-5, adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout_cutoff_index = 1, Loading Loading @@ -568,6 +568,12 @@ class SoundStream(nn.Module): codes = self.rq.get_codes_from_indices(quantized_indices) x = reduce(codes, 'q ... -> ...', 'sum') return self.decode(x) def decode(self, x, quantize = False): if quantize: x, *_ = self.rq(x) x = self.decoder_attn(x) x = rearrange(x, 'b n c -> b c n') return self.decoder(x) Loading Loading @@ -664,14 +670,14 @@ class SoundStream(nn.Module): x, indices, commit_loss = self.rq(x) if return_encoded: return x, indices, commit_loss if exists(self.decoder_attn): x = self.decoder_attn(x) x = rearrange(x, 'b n c -> b c n') if return_encoded: return x, indices, commit_loss recon_x = self.decoder(x) if return_recons_only: Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.26.8' __version__ = '0.27.0' Loading
audiolm_pytorch/soundstream.py +10 −4 Original line number Diff line number Diff line Loading @@ -422,7 +422,7 @@ class SoundStream(nn.Module): multi_spectral_n_ffts = 512, multi_spectral_n_mels = 64, recon_loss_weight = 1., multi_spectral_recon_loss_weight = 1., multi_spectral_recon_loss_weight = 1e-5, adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout_cutoff_index = 1, Loading Loading @@ -568,6 +568,12 @@ class SoundStream(nn.Module): codes = self.rq.get_codes_from_indices(quantized_indices) x = reduce(codes, 'q ... -> ...', 'sum') return self.decode(x) def decode(self, x, quantize = False): if quantize: x, *_ = self.rq(x) x = self.decoder_attn(x) x = rearrange(x, 'b n c -> b c n') return self.decoder(x) Loading Loading @@ -664,14 +670,14 @@ class SoundStream(nn.Module): x, indices, commit_loss = self.rq(x) if return_encoded: return x, indices, commit_loss if exists(self.decoder_attn): x = self.decoder_attn(x) x = rearrange(x, 'b n c -> b c n') if return_encoded: return x, indices, commit_loss recon_x = self.decoder(x) if return_recons_only: Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.26.8' __version__ = '0.27.0'