Commit 6ed3aad4 authored by Phil Wang's avatar Phil Wang
Browse files

incorporate all of @turian feedback

parent be547023
Loading
Loading
Loading
Loading
+1 −3
Original line number Diff line number Diff line
@@ -37,7 +37,6 @@ class SoundDataset(Dataset):
        self,
        folder,
        exts = ['flac', 'wav'],
        normalize = False,
        max_length: OptionalIntOrTupleInt = None,
        target_sample_hz: OptionalIntOrTupleInt = None,
        seq_len_multiple_of: OptionalIntOrTupleInt = None
@@ -50,7 +49,6 @@ class SoundDataset(Dataset):
        assert len(files) > 0, 'no sound files found'

        self.files = files
        self.normalize = normalize

        self.target_sample_hz = cast_tuple(target_sample_hz)
        num_outputs = len(self.target_sample_hz)
@@ -66,7 +64,7 @@ class SoundDataset(Dataset):
    def __getitem__(self, idx):
        file = self.files[idx]

        data, sample_hz = torchaudio.load(file, normalize = self.normalize)
        data, sample_hz = torchaudio.load(file)

        assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'

+5 −6
Original line number Diff line number Diff line
@@ -168,7 +168,7 @@ class ComplexSTFTDiscriminator(nn.Module):
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        normalized = False
        stft_normalized = False
    ):
        super().__init__()
        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)
@@ -188,7 +188,7 @@ class ComplexSTFTDiscriminator(nn.Module):

        # stft settings

        self.normalized = normalized
        self.stft_normalized = stft_normalized

        self.n_fft = n_fft
        self.hop_length = hop_length
@@ -210,7 +210,7 @@ class ComplexSTFTDiscriminator(nn.Module):
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            normalized = self.normalized,
            normalized = self.stft_normalized,
            return_complex = True
        )

@@ -358,7 +358,6 @@ class SoundStream(nn.Module):
        multi_spectral_window_powers_of_two = tuple(range(6, 12)),
        multi_spectral_n_ffts = 512,
        multi_spectral_n_mels = 64,
        multi_spectral_normalized = False,
        recon_loss_weight = 1.,
        multi_spectral_recon_loss_weight = 1.,
        adversarial_loss_weight = 1.,
@@ -447,7 +446,7 @@ class SoundStream(nn.Module):
        self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

        self.stft_discriminator = ComplexSTFTDiscriminator(
            normalized = stft_normalized
            stft_normalized = stft_normalized
        )

        # multi spectral reconstruction
@@ -473,7 +472,7 @@ class SoundStream(nn.Module):
                win_length = win_length,
                hop_length = win_length // 4,
                n_mels = n_mels,
                normalized = multi_spectral_normalized
                normalized = stft_normalized
            )

            self.mel_spec_transforms.append(melspec_transform)
+3 −9
Original line number Diff line number Diff line
@@ -117,7 +117,6 @@ class SoundStreamTrainer(nn.Module):
        batch_size,
        data_max_length = None,
        folder,
        dataset_normalize = False,
        lr = 2e-4,
        grad_accum_every = 4,
        wd = 0.,
@@ -168,8 +167,7 @@ class SoundStreamTrainer(nn.Module):
            folder,
            max_length = data_max_length,
            target_sample_hz = soundstream.target_sample_hz,
            seq_len_multiple_of = soundstream.seq_len_multiple_of,
            normalize = dataset_normalize
            seq_len_multiple_of = soundstream.seq_len_multiple_of
        )

        # split for validation
@@ -437,7 +435,6 @@ class SemanticTransformerTrainer(nn.Module):
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        dataset_normalize = False,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -487,8 +484,7 @@ class SemanticTransformerTrainer(nn.Module):
                folder,
                max_length = data_max_length,
                target_sample_hz = wav2vec.target_sample_hz,
                seq_len_multiple_of = wav2vec.seq_len_multiple_of,
                normalize = dataset_normalize
                seq_len_multiple_of = wav2vec.seq_len_multiple_of
            )

        self.ds_fields = None
@@ -668,7 +664,6 @@ class CoarseTransformerTrainer(nn.Module):
        dataset: Optional[Dataset] = None,
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
        data_max_length = None,
        dataset_normalize = False,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -724,8 +719,7 @@ class CoarseTransformerTrainer(nn.Module):
                    wav2vec.target_sample_hz,
                    soundstream.target_sample_hz
                ), # need 2 waves resampled differently here
                seq_len_multiple_of = soundstream.seq_len_multiple_of,
                normalize = dataset_normalize
                seq_len_multiple_of = soundstream.seq_len_multiple_of
            )

        self.ds_fields = ds_fields
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.11.14',
  version = '0.11.15',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',