Commit 2f4407b4 authored by Phil Wang's avatar Phil Wang
Browse files

make sure one can also finely specify the max length of each target sample freq

parent f7c26f3b
Loading
Loading
Loading
Loading
+22 −14
Original line number Diff line number Diff line
@@ -25,6 +25,10 @@ def exists(val):
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# type

OptionalIntOrTupleInt = Optional[Union[int, Tuple[Optional[int], ...]]]

# dataset functions

@beartype
@@ -33,9 +37,9 @@ class SoundDataset(Dataset):
        self,
        folder,
        exts = ['flac', 'wav'],
        max_length = None,
        target_sample_hz: Optional[Union[int, Tuple[Optional[int], ...]]] = None,
        seq_len_multiple_of = None
        max_length: OptionalIntOrTupleInt = None,
        target_sample_hz: OptionalIntOrTupleInt = None,
        seq_len_multiple_of: OptionalIntOrTupleInt = None
    ):
        super().__init__()
        path = Path(folder)
@@ -45,10 +49,14 @@ class SoundDataset(Dataset):
        assert len(files) > 0, 'no sound files found'

        self.files = files
        self.max_length = max_length

        self.target_sample_hz = cast_tuple(target_sample_hz)
        self.seq_len_multiple_of = seq_len_multiple_of
        num_outputs = len(self.target_sample_hz)

        self.max_length = cast_tuple(max_length, num_outputs)
        self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)

        assert len(self.max_length) == len(self.target_sample_hz) == len(self.seq_len_multiple_of)

    def __len__(self):
        return len(self.files)
@@ -69,26 +77,26 @@ class SoundDataset(Dataset):

        # process each of the data resample at different frequencies individually

        for data in data_tuple:
        for data, max_length, seq_len_multiple_of in zip(data_tuple, self.max_length, self.seq_len_multiple_of):
            audio_length = data.size(1)

            # pad or curtail

            if audio_length > self.max_length:
                max_start = audio_length - self.max_length
            if audio_length > max_length:
                max_start = audio_length - max_length
                start = torch.randint(0, max_start, (1, ))
                data = data[:, start:start + self.max_length]
                data = data[:, start:start + max_length]

            else:
                data = F.pad(data, (0, self.max_length - audio_length), 'constant')
                data = F.pad(data, (0, max_length - audio_length), 'constant')

            data = rearrange(data, '1 ... -> ...')

            if exists(self.max_length):
                data = data[:self.max_length]
            if exists(max_length):
                data = data[:max_length]

            if exists(self.seq_len_multiple_of):
                data = curtail_to_multiple(data, self.seq_len_multiple_of)
            if exists(seq_len_multiple_of):
                data = curtail_to_multiple(data, seq_len_multiple_of)

            output.append(data.float())

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.5',
  version = '0.4.6',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',