Commit 9733c558 authored by Phil Wang's avatar Phil Wang
Browse files

add ability to set length of audio being trained on in seconds `data_max_length_seconds`

parent c53245de
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -57,7 +57,7 @@ trainer = SoundStreamTrainer(
    folder = '/path/to/audio/files',
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    data_max_length_seconds = 2,  # train on 2 second audio
    num_train_steps = 10000
).cuda()

+24 −0
Original line number Diff line number Diff line
@@ -117,6 +117,7 @@ class SoundStreamTrainer(nn.Module):
        num_train_steps,
        batch_size,
        data_max_length = None,
        data_max_length_seconds = None,
        folder,
        lr = 2e-4,
        grad_accum_every = 4,
@@ -173,6 +174,11 @@ class SoundStreamTrainer(nn.Module):

        # create dataset

        assert not (exists(data_max_length) and exists(data_max_length_seconds))

        if exists(data_max_length_seconds):
            data_max_length = data_max_length_seconds * soundstream.target_sample_hz

        self.ds = SoundDataset(
            folder,
            max_length = data_max_length,
@@ -482,6 +488,7 @@ class SemanticTransformerTrainer(nn.Module):
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        data_max_length_seconds = None,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -528,6 +535,11 @@ class SemanticTransformerTrainer(nn.Module):
        if not exists(self.ds):
            assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training'

            assert not (exists(data_max_length) and exists(data_max_length_seconds))

            if exists(data_max_length_seconds):
                data_max_length = data_max_length_seconds * wav2vec.target_sample_hz

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
@@ -711,6 +723,7 @@ class CoarseTransformerTrainer(nn.Module):
        dataset: Optional[Dataset] = None,
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
        data_max_length = None,
        data_max_length_seconds = None,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -760,6 +773,11 @@ class CoarseTransformerTrainer(nn.Module):
        if not exists(self.ds):
            assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training'

            assert not (exists(data_max_length) and exists(data_max_length_seconds))

            if exists(data_max_length_seconds):
                data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, soundstream.target_sample_hz))

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
@@ -947,6 +965,7 @@ class FineTransformerTrainer(nn.Module):
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        data_max_length_seconds = None,
        dataset_normalize = False,
        folder = None,
        lr = 3e-4,
@@ -995,6 +1014,11 @@ class FineTransformerTrainer(nn.Module):
        if not exists(self.ds):
            assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training'

            assert not (exists(data_max_length) and exists(data_max_length_seconds))

            if exists(data_max_length_seconds):
                data_max_length = data_max_length_seconds * soundstream.target_sample_hz

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.14.2',
  version = '0.14.3',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',