Commit f7c26f3b authored by Phil Wang
Browse files

fix order in which audio is resampled vs padded / curtailed

parent 7f788a62
Loading
Loading
Loading
Loading
+34 −18
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ import torchaudio
from torchaudio.functional import resample

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

@@ -57,35 +58,50 @@ class SoundDataset(Dataset):

        data, sample_hz = torchaudio.load(file)
        
        if data.size(1) > self.max_length:
            max_start = data.size(1) - self.max_length
        num_outputs = len(self.target_sample_hz)
        data = cast_tuple(data, num_outputs)

        # resample if target_sample_hz is not None in the tuple

        data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))

        output = []

        # process each copy of the data, resampled at its own target frequency, individually

        for data in data_tuple:
            audio_length = data.size(1)

            # pad or curtail

            if audio_length > self.max_length:
                max_start = audio_length - self.max_length
                start = torch.randint(0, max_start, (1, ))
                data = data[:, start:start + self.max_length]

            else:
            data = torch.nn.functional.pad(data, (0, self.max_length - data.size(1)), 'constant')
                data = F.pad(data, (0, self.max_length - audio_length), 'constant')

            data = rearrange(data, '1 ... -> ...')

        num_outputs = len(self.target_sample_hz)
        data = cast_tuple(data, num_outputs)
            if exists(self.max_length):
                data = data[:self.max_length]

        # resample if target_sample_hz is not None in the tuple
            if exists(self.seq_len_multiple_of):
                data = curtail_to_multiple(data, self.seq_len_multiple_of)

        data = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))
            output.append(data.float())

        if exists(self.max_length):
            data = tuple(d[:self.max_length] for d in data)
        # cast from list to tuple

        if exists(self.seq_len_multiple_of):
            data = tuple(curtail_to_multiple(d, self.seq_len_multiple_of) for d in data)
        output = tuple(output)

        data = tuple(d.float() for d in data)
        # return only one audio, if only one target resample freq

        if num_outputs == 1:
            return data[0]
            return output[0]

        return data
        return output

# dataloader functions

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.4.4',
  version = '0.4.5',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',