Commit 8e3d1979 authored by Phil Wang
Browse files

make dilation order in encoder and decoder configurable in soundstream

parent 36c39540
Loading
Loading
Loading
Loading
+15 −10
Original line number Diff line number Diff line
import functools
from itertools import cycle
from pathlib import Path
from functools import partial

@@ -230,26 +231,28 @@ def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7):
        CausalConv1d(chan_out, chan_out, 1),
    ))

def EncoderBlock(chan_in, chan_out, stride):
def EncoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)):
    it = cycle(cycle_dilations)
    return nn.Sequential(
        ResidualUnit(chan_in, chan_in, 1),
        ResidualUnit(chan_in, chan_in, 3),
        ResidualUnit(chan_in, chan_in, 9),
        ResidualUnit(chan_in, chan_in, next(it)),
        ResidualUnit(chan_in, chan_in, next(it)),
        ResidualUnit(chan_in, chan_in, next(it)),
        nn.ELU(),
        CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride)
    )

def DecoderBlock(chan_in, chan_out, stride):
def DecoderBlock(chan_in, chan_out, stride, cycle_dilations = (1, 3, 9)):
    even_stride = (stride % 2 == 0)
    padding = (stride + (0 if even_stride else 1)) // 2
    output_padding = 0 if even_stride else 1

    it = cycle(cycle_dilations)
    return nn.Sequential(
        nn.ELU(),
        CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride),
        ResidualUnit(chan_out, chan_out, 1),
        ResidualUnit(chan_out, chan_out, 3),
        ResidualUnit(chan_out, chan_out, 9),
        ResidualUnit(chan_out, chan_out, next(it)),
        ResidualUnit(chan_out, chan_out, next(it)),
        ResidualUnit(chan_out, chan_out, next(it)),
    )

class LocalTransformerBlock(nn.Module):
@@ -282,6 +285,8 @@ class SoundStream(nn.Module):
        rq_ema_decay = 0.95,
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),
        enc_cycle_dilations = (1, 3, 9),
        dec_cycle_dilations = (1, 3, 9),
        recon_loss_weight = 1.,
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
@@ -305,7 +310,7 @@ class SoundStream(nn.Module):
        encoder_blocks = []

        for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides):
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride))
            encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride, enc_cycle_dilations))

        self.encoder = nn.Sequential(
            CausalConv1d(input_channels, channels, 7),
@@ -341,7 +346,7 @@ class SoundStream(nn.Module):
        decoder_blocks = []

        for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)):
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride))
            decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride, dec_cycle_dilations))

        self.decoder = nn.Sequential(
            CausalConv1d(codebook_dim, layer_channels[-1], 7),
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.6.0',
  version = '0.6.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',