Commit d05c020f authored by Phil Wang
Browse files

allow for a custom reconstruction target in soundstream

parent c3fabc36
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -651,6 +651,7 @@ class SoundStream(nn.Module):
    def forward(
        self,
        x,
+        target = None,
        return_encoded = False,
        return_discr_loss = False,
        return_discr_losses_separately = False,
@@ -752,7 +753,9 @@ class SoundStream(nn.Module):

        # recon loss

-        recon_loss = F.mse_loss(orig_x, recon_x)
+        target = default(target, orig_x)  # target can also be passed in, in the case of denoising
+
+        recon_loss = F.mse_loss(target, recon_x)

        # multispectral recon loss - eq (4) and (5) in https://arxiv.org/abs/2107.03312

+1 −1
Original line number Diff line number Diff line
-__version__ = '0.28.1'
+__version__ = '0.28.2'