Commit 49cc3923 authored by Phil Wang's avatar Phil Wang
Browse files

convenience keyword argument audio_text_condition, just have researchers set...

convenience keyword argument audio_text_condition, just have researchers set that to True for musiclm training
parent b3a763c2
Loading
Loading
Loading
Loading
+29 −2
Original line number Diff line number Diff line
@@ -439,7 +439,9 @@ class SemanticTransformer(nn.Module):
        attn_dropout = 0.,
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        cond_dim = None,
        has_condition = False,
        audio_text_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
@@ -449,6 +451,10 @@ class SemanticTransformer(nn.Module):
        super().__init__()
        self.num_semantic_tokens = num_semantic_tokens

        if audio_text_condition:
            has_condition = True
            cond_dim = default(cond_dim, dim)

        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob
@@ -459,7 +465,7 @@ class SemanticTransformer(nn.Module):
        self.eos_id = num_semantic_tokens
        self.pad_id = pad_id

        text_dim = get_encoded_dim(t5_name)
        text_dim = default(cond_dim, get_encoded_dim(t5_name))
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        self.transformer = Transformer(
@@ -557,6 +563,8 @@ class CoarseTransformer(nn.Module):
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_dim = None,
        audio_text_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
@@ -566,6 +574,10 @@ class CoarseTransformer(nn.Module):
        super().__init__()
        self.num_semantic_tokens = num_semantic_tokens

        if audio_text_condition:
            has_condition = True
            cond_dim = default(cond_dim, dim)

        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob
@@ -725,6 +737,8 @@ class FineTransformer(nn.Module):
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_dim = None,
        audio_text_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
@@ -732,6 +746,11 @@ class FineTransformer(nn.Module):
        **kwargs
    ):
        super().__init__()

        if audio_text_condition:
            has_condition = True
            cond_dim = default(cond_dim, dim)

        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob
@@ -917,6 +936,8 @@ class SemanticTransformerWrapper(nn.Module):
        self.transformer = transformer
        self.audio_conditioner = audio_conditioner

        assert not (exists(audio_conditioner) and not transformer.has_condition), 'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'

        assert not exists(self.wav2vec) or self.wav2vec.codebook_size == transformer.num_semantic_tokens, f'num_semantic_tokens on SemanticTransformer must be set to {self.wav2vec.codebook_size}'

        self.unique_consecutive = unique_consecutive
@@ -1088,9 +1109,12 @@ class CoarseTransformerWrapper(nn.Module):
        super().__init__()
        self.soundstream = soundstream
        self.wav2vec = wav2vec
        self.audio_conditioner = audio_conditioner

        self.transformer = transformer
        self.audio_conditioner = audio_conditioner

        assert not (exists(audio_conditioner) and not transformer.has_condition), 'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'

        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id

@@ -1293,9 +1317,12 @@ class FineTransformerWrapper(nn.Module):
    ):
        super().__init__()
        self.soundstream = soundstream

        self.transformer = transformer
        self.audio_conditioner = audio_conditioner

        assert not (exists(audio_conditioner) and not transformer.has_condition), 'if conditioning on audio embeddings from mulan, transformer has_condition must be set to True'

        self.num_fine_quantizers = transformer.num_fine_quantizers
        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        self.eos_id = transformer.eos_id
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.9.2',
  version = '0.9.3',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',