Commit c17ee7d3 authored by Phil Wang's avatar Phil Wang
Browse files

add cross attention layers as well as setup t5 and some conditioning logic,...

add cross attention layers as well as setup t5 and some conditioning logic, for helping researchers explore TTS in this setting
parent 26dfc80f
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ loss.backward()

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work on and open source cutting-edge artificial intelligence research

- <a href="https://huggingface.co/">🤗 Huggingface</a> for their amazing accelerate library
- <a href="https://huggingface.co/">🤗 Huggingface</a> for their amazing accelerate and transformers libraries

- <a href="https://ai.facebook.com/">MetaAI</a> for <a href="https://github.com/facebookresearch/fairseq">Fairseq</a> and the liberal license

@@ -58,6 +58,7 @@ loss.backward()

- [x] complete CoarseTransformer
- [x] use fairseq vq-wav2vec for embeddings
- [x] add conditioning

- [ ] incorporate ability to use hubert intermediate features as semantic tokens, recommended by <a href="https://github.com/lucidrains/audiolm-pytorch/discussions/13">eonglints</a>
- [ ] complete full training code for soundstream, taking care of discriminator training
@@ -69,7 +70,8 @@ loss.backward()
- [ ] offer option to weight tie coarse, fine, and semantic embeddings across the 3 hierarchical transformers
- [ ] DRY a little at the end
- [ ] figure out how to suppress logging in fairseq
- [ ] test with speech synthesis for starters, add conditioning + classifier free guidance as well
- [ ] test with speech synthesis for starters
- [ ] add classifier free guidance

## Citations

+127 −22
Original line number Diff line number Diff line
@@ -14,11 +14,16 @@ from vector_quantize_pytorch import ResidualVQ
from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# helper functions

def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing

def default(val, d):
    """Return `val` unless it is None, in which case fall back to `d`."""
    if exists(val):
        return val
    return d

def ceil_div(numer, denom):
    """Divide `numer` by `denom` with floor division after padding the numerator.

    For positive integers this yields the quotient rounded up.
    """
    # adding (denom - 1) before flooring bumps any non-exact quotient up by one
    padded = numer + denom - 1
    return padded // denom

@@ -471,24 +476,53 @@ class Attention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8
        dim_context = None,
        heads = 8,
        norm_context = False,
        num_null_kv = 0
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        dim_context = default(dim_context, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()

        self.num_null_kv = num_null_kv
        self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, attn_bias = None):
    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        b = x.shape[0]

        if exists(context):
            context = self.context_norm(context)

        kv_input = default(context, x)

        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        if self.num_null_kv > 0:
            null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
            k = torch.cat((null_k, k), dim = -2)
            v = torch.cat((null_v, v), dim = -2)

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

@@ -497,8 +531,15 @@ class Attention(nn.Module):
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        if exists(attn_bias):
            attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)
            sim = sim + attn_bias

        if exists(mask):
            mask = F.pad(mask, (self.num_null_kv, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
@@ -518,6 +559,8 @@ class Transformer(nn.Module):
        *,
        dim,
        depth,
        dim_context = None,
        cross_attend = False,
        **kwargs
    ):
        super().__init__()
@@ -527,19 +570,31 @@ class Transformer(nn.Module):

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, **kwargs),
                Attention(dim = dim, causal = True, **kwargs),
                Attention(dim = dim, dim_context = dim_context, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
                FeedForward(dim = dim)
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
    def forward(
        self,
        x,
        context = None,
        context_mask = None
    ):
        n, device = x.shape[1], x.device

        rel_pos_bias = self.rel_pos_bias(n, n, device = device)

        for attn, ff in self.layers:
        for attn, cross_attn, ff in self.layers:
            x = attn(x, attn_bias = rel_pos_bias) + x

            if exists(cross_attn):
                assert exists(context)

                x = cross_attn(x, context = context, mask = context_mask)

            x = ff(x) + x

        return self.norm(x)
@@ -552,16 +607,21 @@ class SemanticTransformer(nn.Module):
        *,
        num_semantic_tokens,
        dim,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        **kwargs
    ):
        super().__init__()
        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)

        self.start_token = nn.Parameter(torch.randn(dim))

        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, **kwargs)
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)
        self.to_logits = nn.Linear(dim, num_semantic_tokens)

    def forward(
@@ -569,8 +629,20 @@ class SemanticTransformer(nn.Module):
        *,
        raw_wave = None,
        ids = None,
        return_loss = False
        return_loss = False,
        text = None,
        text_embed = None
    ):
        device = next(self.parameters()).device

        has_text = exists(text) or exists(text_embed)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embed):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        assert exists(raw_wave) ^ exists(ids)

        if not exists(ids):
@@ -586,7 +658,7 @@ class SemanticTransformer(nn.Module):

        tokens = torch.cat((start_tokens, tokens), dim = 1)

        tokens = self.transformer(tokens)
        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)
        logits = self.to_logits(tokens)

        if not return_loss:
@@ -607,17 +679,22 @@ class CoarseTransformer(nn.Module):
        codebook_size,
        num_coarse_quantizers,
        dim,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        **kwargs
    ):
        super().__init__()
        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)

        self.start_token = nn.Parameter(torch.randn(dim))

        self.semantic_embedding = nn.Embedding(num_semantic_tokens, dim)
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)

        self.wav2vec = wav2vec
        self.transformer = Transformer(dim = dim, **kwargs)
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
@@ -630,9 +707,19 @@ class CoarseTransformer(nn.Module):
        *,
        semantic_token_ids,
        coarse_token_ids,
        text = None,
        text_embed = None
    ):
        b, device = semantic_token_ids.shape[0], semantic_token_ids.device

        has_text = exists(text) or exists(text_embed)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embed):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))

        offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
@@ -649,7 +736,7 @@ class CoarseTransformer(nn.Module):

        tokens = torch.cat((start_tokens, semantic_tokens, coarse_tokens), dim = 1)

        tokens = self.transformer(tokens)
        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)

        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, semantic_seq_len:]

@@ -689,16 +776,20 @@ class FineTransformer(nn.Module):
        num_fine_quantizers,
        codebook_size,
        dim,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        **kwargs
    ):
        super().__init__()
        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)

        self.start_token = nn.Parameter(torch.randn(dim))

        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size, dim)

        self.transformer = Transformer(dim = dim, **kwargs)
        self.transformer = Transformer(dim = dim, dim_context = get_encoded_dim(t5_name), cross_attend = has_condition, **kwargs)

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
@@ -710,10 +801,20 @@ class FineTransformer(nn.Module):
    def forward(
        self,
        coarse_token_ids,
        fine_token_ids
        fine_token_ids,
        text = None,
        text_embed = None
    ):
        device = coarse_token_ids.device

        has_text = exists(text) or exists(text_embed)
        assert not (self.has_condition ^ has_text)

        if not exists(text_embed):
            with torch.no_grad():
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        coarse_token_ids, fine_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, fine_token_ids))

        b, n = coarse_token_ids.shape
@@ -735,7 +836,7 @@ class FineTransformer(nn.Module):

        tokens = torch.cat((start_tokens, coarse_tokens, fine_tokens), dim = 1)

        tokens = self.transformer(tokens)
        tokens = self.transformer(tokens, context = text_embeds, context_mask = text_mask)

        pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:]

@@ -794,7 +895,8 @@ class FineTransformerWrapper(nn.Module):
        raw_wave = None,
        coarse_token_ids = None,
        fine_token_ids = None,
        return_loss = False
        return_loss = False,
        **kwargs
    ):
        assert exists(raw_wave) ^ (exists(coarse_token_ids) and exists(fine_token_ids)), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

@@ -815,7 +917,8 @@ class FineTransformerWrapper(nn.Module):

        coarse_logits, fine_logits = self.transformer(
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids
            fine_token_ids = fine_token_ids,
            **kwargs
        )

        if not return_loss:
@@ -859,7 +962,8 @@ class CoarseTransformerWrapper(nn.Module):
        semantic_token_ids = None,
        raw_wave = None,
        coarse_token_ids = None,
        return_loss = False
        return_loss = False,
        **kwargs
    ):
        assert exists(raw_wave) or exists(semantic_token_ids), 'either raw waveform (raw_wave) is given or semantic token ids are given (semantic_token_ids)'
        assert exists(raw_wave) or exists(coarse_token_ids), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'
@@ -886,7 +990,8 @@ class CoarseTransformerWrapper(nn.Module):

        semantic_logits, coarse_logits = self.transformer(
            semantic_token_ids = semantic_token_ids,
            coarse_token_ids = coarse_token_ids
            coarse_token_ids = coarse_token_ids,
            **kwargs
        )

        if not return_loss:

audiolm_pytorch/t5.py

0 → 100644
+103 −0
Original line number Diff line number Diff line
import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

# less warning messages since only using encoder

transformers.logging.set_verbosity_error()

# helper functions

def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing

# config

# maximum number of tokens per text; passed to the tokenizer as `max_length` with truncation enabled
MAX_LENGTH = 256

# T5 checkpoint name used when the caller does not specify one
DEFAULT_T5_NAME = 'google/t5-v1_1-base'

# per-checkpoint cache: name -> dict that may hold "model", "tokenizer" and/or "config" entries
T5_CONFIGS = {}

# singleton globals

def get_tokenizer(name):
    """Load the pretrained T5 tokenizer for the checkpoint called `name`."""
    return T5Tokenizer.from_pretrained(name)

def get_model(name):
    """Load the pretrained T5 encoder-only model for the checkpoint called `name`."""
    return T5EncoderModel.from_pretrained(name)

def get_model_and_tokenizer(name):
    """Return the `(model, tokenizer)` pair for `name`, loading each at most once.

    Loaded objects are stored in the module-level `T5_CONFIGS` cache under
    the "model" and "tokenizer" keys, so repeated calls reuse them.
    """
    # the cache dict is only mutated, never rebound, so no `global` is needed
    entry = T5_CONFIGS.setdefault(name, dict())

    if "model" not in entry:
        entry["model"] = get_model(name)

    if "tokenizer" not in entry:
        entry["tokenizer"] = get_tokenizer(name)

    return entry['model'], entry['tokenizer']

def get_encoded_dim(name):
    """Return the hidden size (`d_model`) of the T5 checkpoint called `name`.

    The config is taken from the `T5_CONFIGS` cache when available: either a
    previously cached config, or the config attached to an already loaded
    model. Otherwise it is fetched with `T5Config.from_pretrained` and cached
    under the "config" key.

    Args:
        name: T5 checkpoint name, as accepted by `T5Config.from_pretrained`.

    Returns:
        The `d_model` attribute of the checkpoint's config.
    """
    entry = T5_CONFIGS.get(name, {})

    if "config" in entry:
        return entry["config"].d_model

    if "model" in entry:
        return entry["model"].config.d_model

    # Either the name was never seen, or its cache entry holds neither a config
    # nor a model (e.g. the entry was created but the model failed to load).
    # In both cases fetch the config instead of failing on the stale entry.
    config = T5Config.from_pretrained(name)
    T5_CONFIGS.setdefault(name, dict())["config"] = config
    return config.d_model

# encoding text

def t5_encode_text(
    texts,
    name = DEFAULT_T5_NAME,
    output_device = None
):
    """Encode `texts` with a frozen T5 encoder and zero out padded positions.

    Args:
        texts: batch of texts handed to the tokenizer's `batch_encode_plus`
            (presumably a list of strings - confirm against callers).
        name: T5 checkpoint name used to look up the cached model and tokenizer.
        output_device: optional device to move the returned tensor to. When
            None, the result stays on the device the T5 model runs on.

    Returns:
        The encoder's `last_hidden_state`, detached from the graph, with every
        position excluded by the tokenizer's attention mask filled with 0.
    """
    t5, tokenizer = get_model_and_tokenizer(name)

    if torch.cuda.is_available():
        t5 = t5.cuda()

    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(
        texts,
        return_tensors = 'pt',
        padding = 'longest',
        max_length = MAX_LENGTH,
        truncation = True
    )

    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)

    t5.eval()

    with torch.no_grad():
        output = t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()

    # add a trailing axis so the mask broadcasts over the feature dimension
    attn_mask = attn_mask[..., None].bool()

    # zero the padded positions while the mask and the encodings share a device
    encoded_text = encoded_text.masked_fill(~attn_mask, 0.)

    if exists(output_device):
        # Tensor.to returns a new tensor rather than moving in place,
        # so the result must be assigned
        encoded_text = encoded_text.to(output_device)

    return encoded_text
+2 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.7',
  version = '0.0.8',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',
@@ -24,6 +24,7 @@ setup(
    'fairseq',
    'joblib',
    'torch>=1.6',
    'transformers',
    'vector-quantize-pytorch>=0.10.5'
  ],
  classifiers=[