Commit 1831daed authored by Phil Wang
Browse files

allow maximal flexibility in normalizing, either when loading data, or doing spectrogram transform

Allow maximal flexibility in normalizing, either when loading data or when doing the spectrogram transform. See https://github.com/lucidrains/audiolm-pytorch/issues/82
parent 0863e897
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ class SoundDataset(Dataset):
        self,
        folder,
        exts = ['flac', 'wav'],
        normalize = False,
        max_length: OptionalIntOrTupleInt = None,
        target_sample_hz: OptionalIntOrTupleInt = None,
        seq_len_multiple_of: OptionalIntOrTupleInt = None
@@ -49,6 +50,7 @@ class SoundDataset(Dataset):
        assert len(files) > 0, 'no sound files found'

        self.files = files
        self.normalize = normalize

        self.target_sample_hz = cast_tuple(target_sample_hz)
        num_outputs = len(self.target_sample_hz)
@@ -64,7 +66,7 @@ class SoundDataset(Dataset):
    def __getitem__(self, idx):
        file = self.files[idx]

        data, sample_hz = torchaudio.load(file)
        data, sample_hz = torchaudio.load(file, normalize = self.normalize)

        assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'

+11 −2
Original line number Diff line number Diff line
@@ -167,7 +167,8 @@ class ComplexSTFTDiscriminator(nn.Module):
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024
        win_length = 1024,
        normalized = False
    ):
        super().__init__()
        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)
@@ -187,6 +188,8 @@ class ComplexSTFTDiscriminator(nn.Module):

        # stft settings

        self.normalized = normalized

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
@@ -207,6 +210,7 @@ class ComplexSTFTDiscriminator(nn.Module):
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            normalized = self.normalized,
            return_complex = True
        )

@@ -348,11 +352,13 @@ class SoundStream(nn.Module):
        rq_ema_decay = 0.95,
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),
        stft_normalized = False,
        enc_cycle_dilations = (1, 3, 9),
        dec_cycle_dilations = (1, 3, 9),
        multi_spectral_window_powers_of_two = tuple(range(6, 12)),
        multi_spectral_n_ffts = 512,
        multi_spectral_n_mels = 64,
        multi_spectral_normalized = False,
        recon_loss_weight = 1.,
        multi_spectral_recon_loss_weight = 1.,
        adversarial_loss_weight = 1.,
@@ -440,7 +446,9 @@ class SoundStream(nn.Module):
        self.discr_multi_scales = discr_multi_scales
        self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

        self.stft_discriminator = ComplexSTFTDiscriminator()
        self.stft_discriminator = ComplexSTFTDiscriminator(
            normalized = stft_normalized
        )

        # multi spectral reconstruction

@@ -465,6 +473,7 @@ class SoundStream(nn.Module):
                win_length = win_length,
                hop_length = win_length // 4,
                n_mels = n_mels,
                normalized = multi_spectral_normalized
            )

            self.mel_spec_transforms.append(melspec_transform)
+12 −4
Original line number Diff line number Diff line
@@ -117,6 +117,7 @@ class SoundStreamTrainer(nn.Module):
        batch_size,
        data_max_length = None,
        folder,
        dataset_normalize = False,
        lr = 2e-4,
        grad_accum_every = 4,
        wd = 0.,
@@ -167,7 +168,8 @@ class SoundStreamTrainer(nn.Module):
            folder,
            max_length = data_max_length,
            target_sample_hz = soundstream.target_sample_hz,
            seq_len_multiple_of = soundstream.seq_len_multiple_of
            seq_len_multiple_of = soundstream.seq_len_multiple_of,
            normalize = dataset_normalize
        )

        # split for validation
@@ -435,6 +437,7 @@ class SemanticTransformerTrainer(nn.Module):
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        dataset_normalize = False,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -484,7 +487,8 @@ class SemanticTransformerTrainer(nn.Module):
                folder,
                max_length = data_max_length,
                target_sample_hz = wav2vec.target_sample_hz,
                seq_len_multiple_of = wav2vec.seq_len_multiple_of
                seq_len_multiple_of = wav2vec.seq_len_multiple_of,
                normalize = dataset_normalize
            )

        self.ds_fields = None
@@ -664,6 +668,7 @@ class CoarseTransformerTrainer(nn.Module):
        dataset: Optional[Dataset] = None,
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
        data_max_length = None,
        dataset_normalize = False,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -719,7 +724,8 @@ class CoarseTransformerTrainer(nn.Module):
                    wav2vec.target_sample_hz,
                    soundstream.target_sample_hz
                ), # need 2 waves resampled differently here
                seq_len_multiple_of = soundstream.seq_len_multiple_of
                seq_len_multiple_of = soundstream.seq_len_multiple_of,
                normalize = dataset_normalize
            )

        self.ds_fields = ds_fields
@@ -900,6 +906,7 @@ class FineTransformerTrainer(nn.Module):
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        dataset_normalize = False,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
@@ -950,7 +957,8 @@ class FineTransformerTrainer(nn.Module):
                folder,
                max_length = data_max_length,
                target_sample_hz = soundstream.target_sample_hz,
                seq_len_multiple_of = soundstream.seq_len_multiple_of
                seq_len_multiple_of = soundstream.seq_len_multiple_of,
                normalize = dataset_normalize
            )

        self.ds_fields = None
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.11.12',
  version = '0.11.14',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',