Commit ce717d99 authored by Phil Wang's avatar Phil Wang
Browse files

whether to use the absolute value of the complex logits for output of complex...

whether to use the absolute value of the complex logits for output of complex stft discriminator, or simply view as real - thanks to @ilya16 for noting this detail
parent 3c639577
Loading
Loading
Loading
Loading
+15 −6
Original line number Diff line number Diff line
@@ -214,7 +214,8 @@ class ComplexSTFTDiscriminator(nn.Module):
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        stft_normalized = False
        stft_normalized = False,
        logits_abs = True
    ):
        super().__init__()
        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)
@@ -240,6 +241,9 @@ class ComplexSTFTDiscriminator(nn.Module):
        self.hop_length = hop_length
        self.win_length = win_length

        # how to output the logits into real space
        self.logits_abs = logits_abs

    def forward(self, x, return_intermediates = False):
        x = rearrange(x, 'b 1 n -> b n')

@@ -273,12 +277,15 @@ class ComplexSTFTDiscriminator(nn.Module):

        complex_logits = self.final_conv(x)

        complex_logits_abs = torch.abs(complex_logits)
        if self.logits_abs:
            complex_logits = complex_logits.abs()
        else:
            complex_logits = torch.view_as_real(complex_logits)

        if not return_intermediates:
            return complex_logits_abs
            return complex_logits

        return complex_logits_abs, intermediates
        return complex_logits, intermediates

# sound stream

@@ -423,7 +430,8 @@ class SoundStream(nn.Module):
        attn_depth = 1,
        attn_xpos_scale_base = None,
        attn_dynamic_pos_bias = False,
        squeeze_excite = False
        squeeze_excite = False,
        complex_stft_discr_logits_abs = True
    ):
        super().__init__()

@@ -506,7 +514,8 @@ class SoundStream(nn.Module):
        self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors])

        self.stft_discriminator = ComplexSTFTDiscriminator(
            stft_normalized = stft_normalized
            stft_normalized = stft_normalized,
            logits_abs = complex_stft_discr_logits_abs  # whether to output as abs() or use view_as_real
        )

        # multi spectral reconstruction
+1 −1
Original line number Diff line number Diff line
__version__ = '0.23.5'
__version__ = '0.23.6'