Commit 2e654ba8 authored by Phil Wang's avatar Phil Wang
Browse files

add ability to condition on joint text-audio embeddings from mulan, over at musiclm-pytorch

parent 04bed934
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.soundstream import SoundStream

from audiolm_pytorch.audiolm_pytorch import SemanticBase, CoarseBase, FineBase
from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper, SemanticTransformerWrapper

+40 −14
Original line number Diff line number Diff line
@@ -21,6 +21,7 @@ from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
from torchaudio.functional import resample

from audiolm_pytorch.soundstream import SoundStream
from audiolm_pytorch.utils import AudioConditionerBase

from tqdm import tqdm

@@ -424,21 +425,10 @@ class Transformer(nn.Module):

        return self.norm(x)

# bases

class SemanticBase(nn.Module):
    pass

class CoarseBase(nn.Module):
    pass

class FineBase(nn.Module):
    pass

# the three hierarchical transformers

@beartype
class SemanticTransformer(SemanticBase):
class SemanticTransformer(nn.Module):
    def __init__(
        self,
        *,
@@ -553,7 +543,7 @@ class SemanticTransformer(SemanticBase):
        return self.to_logits(tokens)

@beartype
class CoarseTransformer(CoarseBase):
class CoarseTransformer(nn.Module):
    def __init__(
        self,
        *,
@@ -721,7 +711,7 @@ class CoarseTransformer(CoarseBase):

        return semantic_logits, coarse_logits

class FineTransformer(FineBase):
class FineTransformer(nn.Module):
    def __init__(
        self,
        *,
@@ -917,6 +907,7 @@ class SemanticTransformerWrapper(nn.Module):
        *,
        transformer: SemanticTransformer,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        pad_id = -1,
        unique_consecutive = True,
        mask_prob = 0.15
@@ -924,6 +915,8 @@ class SemanticTransformerWrapper(nn.Module):
        super().__init__()
        self.wav2vec = wav2vec
        self.transformer = transformer
        self.audio_conditioner = audio_conditioner

        assert not exists(self.wav2vec) or self.wav2vec.codebook_size == transformer.num_semantic_tokens, f'num_semantic_tokens on SemanticTransformer must be set to {self.wav2vec.codebook_size}'

        self.unique_consecutive = unique_consecutive
@@ -969,6 +962,12 @@ class SemanticTransformerWrapper(nn.Module):
        if self.unique_consecutive:
            ids = batch_unique_consecutive(ids, pad_value = self.pad_id)

        # derive joint audio-text embeddings if needed

        if exists(self.audio_conditioner) and exists(prime_wave):
            assert not exists(text) and not exists(text_embeds)
            text_embeds = self.audio_conditioner(prime_wave)

        # derive text embeddings if needed

        has_text = exists(text) or exists(text_embeds)
@@ -1028,6 +1027,11 @@ class SemanticTransformerWrapper(nn.Module):
    ):
        assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'

        if exists(self.audio_conditioner):
            assert exists(raw_wave)
            assert not exists(text) and not exists(text_embeds)
            text_embeds = self.audio_conditioner(raw_wave)

        if not exists(semantic_token_ids):
            assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)
@@ -1075,6 +1079,7 @@ class CoarseTransformerWrapper(nn.Module):
        transformer: CoarseTransformer,
        soundstream: Optional[SoundStream]  = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        pad_id = -1,
        unique_consecutive = True,
        semantic_cross_entropy_loss_weight = 1.,
@@ -1083,6 +1088,7 @@ class CoarseTransformerWrapper(nn.Module):
        super().__init__()
        self.soundstream = soundstream
        self.wav2vec = wav2vec
        self.audio_conditioner = audio_conditioner

        self.transformer = transformer
        self.unique_consecutive = unique_consecutive
@@ -1177,6 +1183,8 @@ class CoarseTransformerWrapper(nn.Module):
        semantic_token_ids = None,
        raw_wave = None,
        raw_wave_for_soundstream = None,
        text = None,
        text_embeds = None,
        coarse_token_ids = None,
        return_loss = False,
        **kwargs
@@ -1188,6 +1196,11 @@ class CoarseTransformerWrapper(nn.Module):

        assert not all(map(exists, (raw_wave, raw_wave_for_soundstream, semantic_token_ids, coarse_token_ids)))

        if exists(self.audio_conditioner):
            assert exists(raw_wave)
            assert not exists(text) and not exists(text_embeds)
            text_embeds = self.audio_conditioner(raw_wave) # technically audio embeds, but shared text-audio joint embedding space for mulan

        if not exists(semantic_token_ids):
            assert exists(self.wav2vec), 'VQWav2Vec must be be provided if given raw wave for training'
            semantic_token_ids = self.wav2vec(raw_wave, flatten = False)
@@ -1226,6 +1239,8 @@ class CoarseTransformerWrapper(nn.Module):
            semantic_token_ids = semantic_token_ids,
            coarse_token_ids = coarse_token_ids,
            self_attn_mask = self_attn_mask,
            text = text,
            text_embeds = text_embeds,
            **kwargs
        )

@@ -1271,6 +1286,7 @@ class FineTransformerWrapper(nn.Module):
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        coarse_cross_entropy_loss_weight = 1.,
        pad_id = -1,
        mask_prob = 0.15
@@ -1278,6 +1294,7 @@ class FineTransformerWrapper(nn.Module):
        super().__init__()
        self.soundstream = soundstream
        self.transformer = transformer
        self.audio_conditioner = audio_conditioner

        self.num_fine_quantizers = transformer.num_fine_quantizers
        self.num_coarse_quantizers = transformer.num_coarse_quantizers
@@ -1391,6 +1408,8 @@ class FineTransformerWrapper(nn.Module):
        self,
        *,
        raw_wave = None,
        text = None,
        text_embeds = None,
        token_ids = None,
        coarse_token_ids = None,
        fine_token_ids = None,
@@ -1399,6 +1418,11 @@ class FineTransformerWrapper(nn.Module):
    ):
        assert exists(raw_wave) ^ (exists(token_ids) ^ (exists(coarse_token_ids) and exists(fine_token_ids))), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

        if exists(self.audio_conditioner):
            assert exists(raw_wave)
            assert not exists(text) and not exists(text_embeds)
            text_embeds = self.audio_conditioner(raw_wave) # technically audio embeds, but shared text-audio joint embedding space for mulan

        if exists(raw_wave):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'

@@ -1432,6 +1456,8 @@ class FineTransformerWrapper(nn.Module):
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids,
            self_attn_mask = self_attn_mask,
            text = text,
            text_embeds = text_embeds,
            **kwargs
        )

+16 −9
Original line number Diff line number Diff line
@@ -25,9 +25,6 @@ from ema_pytorch import EMA
from audiolm_pytorch.soundstream import SoundStream

from audiolm_pytorch.audiolm_pytorch import (
    SemanticBase,
    CoarseBase,
    FineBase,
    SemanticTransformer,
    SemanticTransformerWrapper,
    CoarseTransformer,
@@ -39,6 +36,7 @@ from audiolm_pytorch.audiolm_pytorch import (
)

from audiolm_pytorch.data import SoundDataset, get_dataloader
from audiolm_pytorch.utils import AudioConditionerBase

from accelerate import Accelerator

@@ -429,10 +427,11 @@ class SemanticTransformerTrainer(nn.Module):
    def __init__(
        self,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        transformer: SemanticBase,
        transformer: SemanticTransformer,
        *,
        num_train_steps,
        batch_size,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        folder = None,
@@ -452,10 +451,12 @@ class SemanticTransformerTrainer(nn.Module):

        self.wav2vec = wav2vec
        self.transformer = transformer
        self.audio_conditioner = audio_conditioner

        self.train_wrapper = SemanticTransformerWrapper(
            wav2vec = wav2vec,
            transformer = transformer
            transformer = transformer,
            audio_conditioner = audio_conditioner
        )

        self.register_buffer('steps', torch.Tensor([0]))
@@ -652,12 +653,13 @@ class SemanticTransformerTrainer(nn.Module):
class CoarseTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer: CoarseBase,
        transformer: CoarseTransformer,
        soundstream: SoundStream,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        *,
        num_train_steps,
        batch_size,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
        data_max_length = None,
@@ -679,11 +681,13 @@ class CoarseTransformerTrainer(nn.Module):
        self.transformer = transformer
        self.soundstream = soundstream
        self.wav2vec = wav2vec
        self.audio_conditioner = audio_conditioner

        self.train_wrapper = CoarseTransformerWrapper(
            soundstream = soundstream,
            wav2vec = wav2vec,
            transformer = transformer
            transformer = transformer,
            audio_conditioner = audio_conditioner
        )

        self.register_buffer('steps', torch.Tensor([0]))
@@ -887,11 +891,12 @@ class CoarseTransformerTrainer(nn.Module):
class FineTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer: FineBase,
        transformer: FineTransformer,
        soundstream: SoundStream,
        *,
        num_train_steps,
        batch_size,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        folder = None,
@@ -911,10 +916,12 @@ class FineTransformerTrainer(nn.Module):

        self.transformer = transformer
        self.soundstream = soundstream
        self.audio_conditioner = audio_conditioner

        self.train_wrapper = FineTransformerWrapper(
            soundstream = soundstream,
            transformer = transformer
            transformer = transformer,
            audio_conditioner = audio_conditioner
        )

        self.register_buffer('steps', torch.Tensor([0]))
+13 −1
Original line number Diff line number Diff line
from torch import nn

# functions

def round_down_nearest_multiple(num, divisor):
    return num // divisor * divisor

def curtail_to_multiple(t, mult):
    data_len = t.shape[-1]
    return t[..., :(data_len // mult * mult)]
    return t[..., :round_down_nearest_multiple(data_len, mult)]

# base class

class AudioConditionerBase(nn.Module):
    pass
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.8.3',
  version = '0.9.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',