Commit e4b66cf0 authored by Phil Wang's avatar Phil Wang
Browse files

add an autoregressive squeeze excitation module. go for the knockout

parent af903342
Loading
Loading
Loading
Loading
+53 −14
Original line number Diff line number Diff line
@@ -75,6 +75,11 @@ def gradient_penalty(wave, output, weight = 10):
    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()

# better sequential

def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))

# discriminators

class MultiScaleDiscriminator(nn.Module):
@@ -124,6 +129,34 @@ class MultiScaleDiscriminator(nn.Module):

        return out, intermediates

# autoregressive squeeze excitation

class SqueezeExcite(nn.Module):
    def __init__(self, dim, reduction_factor = 4, dim_minimum = 8):
        super().__init__()
        dim_inner = max(dim_minimum, dim // reduction_factor)
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim_inner, 1),
            nn.SiLU(),
            nn.Conv1d(dim_inner, dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        seq, device = x.shape[-2], x.device

        # cumulative mean - since it is autoregressive

        cum_sum = x.cumsum(dim = -2)
        denom = torch.arange(1, seq + 1, device = device).float()
        cum_mean = cum_sum / rearrange(denom, 'n -> n 1')

        # glu gate

        gate = self.net(cum_mean)

        return x * gate

# complex stft discriminator

class ModReLU(nn.Module):
@@ -284,34 +317,39 @@ class CausalConvTranspose1d(nn.Module):

        return out

def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7):
    return Residual(nn.Sequential(
def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False):
    return Residual(Sequential(
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation),
        nn.ELU(),
        CausalConv1d(chan_out, chan_out, 1),
        nn.ELU()
        nn.ELU(),
        SqueezeExcite(chan_out) if squeeze_excite else None
    ))

def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)):
def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False):
    it = cycle(cycle_dilations)
    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite)

    return nn.Sequential(
        ResidualUnit(chan_in, chan_in, next(it)),
        ResidualUnit(chan_in, chan_in, next(it)),
        ResidualUnit(chan_in, chan_in, next(it)),
        residual_unit(chan_in, chan_in, next(it)),
        residual_unit(chan_in, chan_in, next(it)),
        residual_unit(chan_in, chan_in, next(it)),
        CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride)
    )

def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)):
def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False):
    even_stride = (stride % 2 == 0)
    padding = (stride + (0 if even_stride else 1)) // 2
    output_padding = 0 if even_stride else 1

    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite)

    it = cycle(cycle_dilations)
    return nn.Sequential(
        CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
        ResidualUnit(chan_out, chan_out, next(it)),
        ResidualUnit(chan_out, chan_out, next(it)),
        ResidualUnit(chan_out, chan_out, next(it)),
        residual_unit(chan_out, chan_out, next(it)),
        residual_unit(chan_out, chan_out, next(it)),
        residual_unit(chan_out, chan_out, next(it)),
    )

class LocalTransformer(nn.Module):
@@ -383,7 +421,8 @@ class SoundStream(nn.Module):
        attn_heads = 8,
        attn_depth = 1,
        attn_xpos_scale_base = None,
        attn_dynamic_pos_bias = False
        attn_dynamic_pos_bias = False,
        squeeze_excite = False
    ):
        super().__init__()

@@ -408,7 +447,7 @@ class SoundStream(nn.Module):
        encoder_blocks = []

        for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides):
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations))
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite))

        self.encoder = nn.Sequential(
            CausalConv1d(input_channels, channels, 7),
@@ -450,7 +489,7 @@ class SoundStream(nn.Module):
        decoder_blocks = []

        for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)):
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations))
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite))

        self.decoder = nn.Sequential(
            CausalConv1d(codebook_dim, layer_channels[-1], 7),
+1 −1
Original line number Diff line number Diff line
__version__ = '0.22.0'
__version__ = '0.22.1'