Commit 37d9efab authored by Phil Wang
Browse files

remove final blocker around training code for coarse transformer, which...

remove final blocker around training code for coarse transformer, which requires the raw wave to be resampled at different frequencies (a separate raw_wave_for_soundstream can now be passed for soundstream)
parent f760d083
Loading
Loading
Loading
Loading
+7 −3
Original line number Diff line number Diff line
@@ -954,13 +954,17 @@ class CoarseTransformerWrapper(nn.Module):
        *,
        semantic_token_ids = None,
        raw_wave = None,
        raw_wave_for_soundstream = None,
        coarse_token_ids = None,
        return_loss = False,
        **kwargs
    ):
        assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'
        assert exists(raw_wave) or exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
        assert not all(map(exists, (raw_wave, semantic_token_ids, coarse_token_ids)))

        raw_wave_for_soundstream = default(raw_wave_for_soundstream, raw_wave)
        assert exists(raw_wave_for_soundstream) or exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

        assert not all(map(exists, (raw_wave, raw_wave_for_soundstream, semantic_token_ids, coarse_token_ids)))

        if not exists(semantic_token_ids):
            assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
@@ -971,7 +975,7 @@ class CoarseTransformerWrapper(nn.Module):

            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                _, indices, _ = self.soundstream(raw_wave_for_soundstream, return_encoded = True)
                coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:]

        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.48',
  version = '0.0.49',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',