Commit 657790d6 authored by Phil Wang's avatar Phil Wang
Browse files

allow the padding mode to be customizable in soundstream for causal conv

parent 1f0aefe1
Loading
Loading
Loading
Loading
+17 −15
Original line number Diff line number Diff line
@@ -305,17 +305,18 @@ class Residual(nn.Module):
        return self.fn(x, **kwargs) + x

class CausalConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, **kwargs):
    def __init__(self, chan_in, chan_out, kernel_size, pad_mode = 'reflect', **kwargs):
        super().__init__()
        kernel_size = kernel_size
        dilation = kwargs.get('dilation', 1)
        stride = kwargs.get('stride', 1)
        self.pad_mode = pad_mode
        self.causal_padding = dilation * (kernel_size - 1) + (1 - stride)

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs)

    def forward(self, x):
        x = F.pad(x, (self.causal_padding, 0), mode = 'reflect')
        x = F.pad(x, (self.causal_padding, 0), mode = self.pad_mode)
        return self.conv(x)

class CausalConvTranspose1d(nn.Module):
@@ -333,18 +334,18 @@ class CausalConvTranspose1d(nn.Module):

        return out

def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False):
def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7, squeeze_excite = False, pad_mode = 'reflect'):
    return Residual(Sequential(
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation),
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation, pad_mode = pad_mode),
        nn.ELU(),
        CausalConv1d(chan_out, chan_out, 1),
        CausalConv1d(chan_out, chan_out, 1, pad_mode = pad_mode),
        nn.ELU(),
        SqueezeExcite(chan_out) if squeeze_excite else None
    ))

def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False):
def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
    it = cycle(cycle_dilations)
    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite)
    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)

    return nn.Sequential(
        residual_unit(chan_in, chan_in, next(it)),
@@ -353,12 +354,12 @@ def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze
        CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride)
    )

def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False):
def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9), squeeze_excite = False, pad_mode = 'reflect'):
    even_stride = (stride % 2 == 0)
    padding = (stride + (0 if even_stride else 1)) // 2
    output_padding = 0 if even_stride else 1

    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite)
    residual_unit = partial(ResidualUnit, squeeze_excite = squeeze_excite, pad_mode = pad_mode)

    it = cycle(cycle_dilations)
    return nn.Sequential(
@@ -450,6 +451,7 @@ class SoundStream(nn.Module):
        attn_dynamic_pos_bias = False,
        squeeze_excite = False,
        complex_stft_discr_logits_abs = True,
        pad_mode = 'reflect',
        stft_discriminator: Optional[nn.Module] = None  # can pass in own stft discriminator
    ):
        super().__init__()
@@ -475,12 +477,12 @@ class SoundStream(nn.Module):
        encoder_blocks = []

        for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides):
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite))
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations, squeeze_excite, pad_mode))

        self.encoder = nn.Sequential(
            CausalConv1d(input_channels, channels, 7),
            CausalConv1d(input_channels, channels, 7, pad_mode = pad_mode),
            *encoder_blocks,
            CausalConv1d(layer_channels[-1], codebook_dim, 3)
            CausalConv1d(layer_channels[-1], codebook_dim, 3, pad_mode = pad_mode)
        )

        attn_kwargs = dict(
@@ -526,12 +528,12 @@ class SoundStream(nn.Module):
        decoder_blocks = []

        for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)):
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite))
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations, squeeze_excite, pad_mode))

        self.decoder = nn.Sequential(
            CausalConv1d(codebook_dim, layer_channels[-1], 7),
            CausalConv1d(codebook_dim, layer_channels[-1], 7, pad_mode = pad_mode),
            *decoder_blocks,
            CausalConv1d(channels, input_channels, 7)
            CausalConv1d(channels, input_channels, 7, pad_mode = pad_mode)
        )

        # discriminators
+1 −1
Original line number Diff line number Diff line
__version__ = '0.30.0'
__version__ = '0.30.1'