Commit 38175d4b authored by Phil Wang
Browse files

fix a bug with residual quantize dropout, and also figure out a way to deal...

fix a bug with residual quantize dropout, and also figure out a way to deal with the input sequence needing to have a sequence length that is a multiple of the cumulative product of the strides in soundstream
parent b99a260f
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
import math
import functools
from functools import partial
from typing import Optional, Union

@@ -298,6 +299,7 @@ class SoundStream(nn.Module):
    ):
        super().__init__()
        self.single_channel = input_channels == 1
        self.strides = strides

        layer_channels = tuple(map(lambda t: t * channels, channel_mults))
        layer_channels = (channels, *layer_channels)
@@ -348,6 +350,10 @@ class SoundStream(nn.Module):
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feature_loss_weight = feature_loss_weight

    @property
    def seq_len_multiple_of(self):
        """Return the product of all entries in ``self.strides``.

        Returns:
            The cumulative product of the strides. An empty ``self.strides``
            yields ``1`` (the empty product) rather than raising ``TypeError``.
        """
        # The explicit initial value 1 is the multiplicative identity: it does
        # not change the result for non-empty strides, but it gives reduce a
        # defined answer when there is nothing to multiply.
        return functools.reduce(lambda x, y: x * y, self.strides, 1)

    def forward(
        self,
        x,
+11 −2
Original line number Diff line number Diff line
@@ -12,7 +12,8 @@ class SoundDataset(Dataset):
    def __init__(
        self,
        folder,
        exts = ['flac', 'wav']
        exts = ['flac', 'wav'],
        seq_len_multiple_of = None
    ):
        super().__init__()
        path = Path(folder)
@@ -20,7 +21,9 @@ class SoundDataset(Dataset):

        files = [file for ext in exts for file in path.glob(f'**/*.{ext}')]
        assert len(files) > 0, 'no sound files found'

        self.files = files
        self.seq_len_multiple_of = seq_len_multiple_of

    def __len__(self):
        """Return how many sound files were collected in ``self.files``."""
        file_count = len(self.files)
        return file_count
@@ -28,7 +31,13 @@ class SoundDataset(Dataset):
    def __getitem__(self, idx):
        """Load the sound file at position ``idx`` and return it as a float tensor.

        When ``self.seq_len_multiple_of`` is set (truthy), trailing samples are
        dropped so that the length along the first axis is a multiple of it.

        Args:
            idx: index into ``self.files``.

        Returns:
            A tensor built from the data read by ``sf.read``, cast with
            ``.float()``.
        """
        file = self.files[idx]
        # the second value returned by sf.read is not used here
        data, _ = sf.read(file)

        # NOTE: there must be no early return before this point, otherwise the
        # truncation and the float cast below would never run.
        mult = self.seq_len_multiple_of
        if mult:
            # floor to the largest multiple of `mult` that fits; a falsy value
            # (None or 0) skips truncation and also avoids dividing by zero
            usable_len = len(data) // mult * mult
            data = data[:usable_len]

        return torch.from_numpy(data).float()

# dataloader functions

+2 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.14',
  version = '0.0.15',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -27,7 +27,7 @@ setup(
    'torch>=1.6',
    'torchaudio',
    'transformers',
    'vector-quantize-pytorch>=0.10.5'
    'vector-quantize-pytorch>=0.10.10'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',