Commit 60e4cd84 authored by Phil Wang
Browse files

Not all of the transformers may need conditioning; perhaps only the semantic one is...

Not all of the transformers may need conditioning; perhaps only the semantic transformer is conditioned, while the rest just increase the resolution.
parent 2a785158
Loading
Loading
Loading
Loading
+10 −3
Original line number Diff line number Diff line
@@ -1393,6 +1393,11 @@ class AudioLM(nn.Module):
        assert coarse_transformer.codebook_size == fine_transformer.codebook_size
        assert coarse_transformer.num_coarse_quantizers == fine_transformer.num_coarse_quantizers

        self.semantic_has_condition = semantic_transformer.has_condition
        self.coarse_has_condition = coarse_transformer.has_condition
        self.fine_has_condition = fine_transformer.has_condition
        self.needs_text = any([self.semantic_has_condition, self.coarse_has_condition, self.fine_has_condition])

        self.semantic = SemanticTransformerWrapper(
            wav2vec = wav2vec,
            transformer = semantic_transformer,
@@ -1427,18 +1432,20 @@ class AudioLM(nn.Module):
        return_coarse_generated_wave = False,
        mask_out_generated_fine_tokens = False
    ):
        assert not (self.needs_text and not exists(text)), 'text needs to be passed in if one of the transformer requires conditioning'

        if exists(prime_wave):
            prime_wave = prime_wave.to(self.device)

        semantic_token_ids = self.semantic.generate(
            text = text,
            text = text if self.semantic_has_condition else None,
            batch_size = batch_size,
            prime_wave = prime_wave,
            max_length = max_length
        )

        coarse_token_ids_or_recon_wave = self.coarse.generate(
            text = text,
            text = text if self.coarse_has_condition else None,
            semantic_token_ids = semantic_token_ids,
            reconstruct_wave = return_coarse_generated_wave
        )
@@ -1447,7 +1454,7 @@ class AudioLM(nn.Module):
            return coarse_token_ids_or_recon_wave

        generated_wave = self.fine.generate(
            text = text,
            text = text if self.fine_has_condition else None,
            coarse_token_ids = coarse_token_ids_or_recon_wave,
            reconstruct_wave = True,
            mask_out_generated_fine_tokens = mask_out_generated_fine_tokens
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.17',
  version = '0.1.18',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',