Commit f760d083 authored by Phil Wang's avatar Phil Wang
Browse files

allow for sound dataset to return multiple resample audio, since for training...

allow for sound dataset to return multiple resample audio, since for training the coarse transformer, the same audio needs to be resampled differently for the hubert or vq-wav2vec vs the soundstream coarse token ids
parent 52cfee78
Loading
Loading
Loading
Loading
+40 −8
Original line number Diff line number Diff line
from pathlib import Path
from functools import partial
from functools import partial, wraps

import torchaudio
from torchaudio.functional import resample

import torch
from torch.nn.utils.rnn import pad_sequence
@@ -10,9 +12,14 @@ from audiolm_pytorch.utils import curtail_to_multiple

from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else ((val,) * length)

# dataset functions

class SoundDataset(Dataset):
@@ -34,7 +41,7 @@ class SoundDataset(Dataset):
        self.files = files
        self.max_length = max_length

        self.target_sample_hz = target_sample_hz
        self.target_sample_hz = cast_tuple(target_sample_hz)
        self.seq_len_multiple_of = seq_len_multiple_of

    def __len__(self):
@@ -42,28 +49,53 @@ class SoundDataset(Dataset):

    def __getitem__(self, idx):
        file = self.files[idx]
        data, sample_hz = torchaudio.load(file)

        data, sample_hz = torchaudio.load(file)
        data = rearrange(data, '1 ... -> ...')

        num_outputs = len(self.target_sample_hz)
        data = cast_tuple(data, num_outputs)

        if exists(self.target_sample_hz):
            data = torchaudio.functional.resample(data, sample_hz, self.target_sample_hz)
            data = tuple(resample(d, sample_hz, target_sample_hz) for d, target_sample_hz in zip(data, self.target_sample_hz))

        if exists(self.max_length):
            data = data[:self.max_length]
            data = tuple(d[:self.max_length] for d in data)

        if exists(self.seq_len_multiple_of):
            data = curtail_to_multiple(data, self.seq_len_multiple_of)
            data = tuple(curtail_to_multiple(d, self.seq_len_multiple_of) for d in data)

        data = tuple(d.float() for d in data)

        return data.float()
        if num_outputs == 1:
            return data[0]

        return data

# dataloader functions

def collate_one_or_multiple_tensors(fn):
    @wraps(fn)
    def inner(data):
        is_one_data = not isinstance(data[0], tuple)

        if is_one_data:
            return fn(data)

        return tuple(map(fn, zip(*data)))

    return inner

@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    min_len = min(*[datum.shape[0] for datum in data])
    data = [datum[:min_len] for datum in data]
    return torch.stack(data)

@collate_one_or_multiple_tensors
def pad_to_longest(data):
    return pad_sequence(data, batch_first = True)

def get_dataloader(ds, pad_to_longest = True, **kwargs):
    collate_fn = partial(pad_sequence, batch_first = True) if pad_to_longest else curtail_to_shortest_collate
    collate_fn = pad_to_longest if pad_to_longest else curtail_to_shortest_collate
    return DataLoader(ds, collate_fn = collate_fn, **kwargs)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.47',
  version = '0.0.48',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',