Commit 654d4969 authored by Phil Wang's avatar Phil Wang
Browse files

allow for prefix-based self attention conditioning, as was done in VALL-E, and...

allow for prefix-based self attention conditioning, as was done in VALL-E, and as requested by researcher @eonglints
parent 5058cc44
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -190,7 +190,8 @@ semantic_transformer = SemanticTransformer(
    num_semantic_tokens = 500,
    dim = 1024,
    depth = 6,
    has_condition = True  # this will have to be set to True
    has_condition = True,               # this will have to be set to True
    cond_as_self_attn_prefix = True     # whether to condition as prefix to self attention, instead of cross attention, as was done in 'VALL-E' paper
).cuda()

# mock text video dataset (as an example)
+78 −9
Original line number Diff line number Diff line
@@ -274,28 +274,58 @@ class Attention(nn.Module):
        x,
        context = None,
        mask = None,
        attn_bias = None
        attn_bias = None,
        prefix_context = None,
        prefix_context_mask = None
    ):
        b = x.shape[0]
        b, n, _, device = *x.shape, x.device

        if exists(context):
            context = self.context_norm(context)

        kv_input = default(context, x)

        # take care of prefix-based self attention conditioning
        # make sure to either concat the to the self attention mask or lengthen it accordingly

        if exists(prefix_context):
            kv_input = torch.cat((prefix_context, kv_input), dim = -2)
            prefix_seq_len = prefix_context.shape[-2]

            if not exists(mask):
                mask = torch.ones((b, n), device = device, dtype = torch.bool)

            if exists(prefix_context_mask):
                mask = torch.cat((prefix_context_mask, mask), dim = -1)
            else:
                mask = F.pad(mask, (prefix_seq_len, 0), value = True)

            if exists(attn_bias):
                attn_bias = F.pad(attn_bias, (prefix_seq_len, 0), value = 0.)

        # prenorm

        x = self.norm(x)

        # project for queries, keys, values

        q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)

        # null key / values

        if self.num_null_kv > 0:
            null_k, null_v = repeat(self.null_kv, 'kv n d -> kv b n d', b = b).unbind(dim = 0)
            k = torch.cat((null_k, k), dim = -2)
            v = torch.cat((null_v, v), dim = -2)

        # split for multi-headed attention

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale

        # similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        if exists(attn_bias):
@@ -312,11 +342,17 @@ class Attention(nn.Module):
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

@@ -334,9 +370,13 @@ class Transformer(nn.Module):
        attn_dropout = 0.,
        ff_dropout = 0.,
        grad_shrink_alpha = 0.1,
        cond_as_self_attn_prefix = False,
        **kwargs
    ):
        super().__init__()
        assert not (cross_attend and cond_as_self_attn_prefix)
        self.cond_as_self_attn_prefix = cond_as_self_attn_prefix

        self.grad_shrink = partial(grad_shrink, alpha = grad_shrink_alpha)

        self.layers = nn.ModuleList([])
@@ -359,14 +399,23 @@ class Transformer(nn.Module):
        context = None,
        context_mask = None
    ):
        assert not (self.cond_as_self_attn_prefix and not exists(context))

        n, device = x.shape[1], x.device

        x = self.grad_shrink(x) # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability

        rel_pos_bias = self.rel_pos_bias(n, n, device = device)

        self_attn_kwargs = dict()
        if self.cond_as_self_attn_prefix:
            self_attn_kwargs = dict(
                prefix_context = context,
                prefix_context_mask = context_mask
            )

        for attn, cross_attn, ff in self.layers:
            x = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask) + x
            x = attn(x, attn_bias = rel_pos_bias, mask = self_attn_mask, **self_attn_kwargs) + x

            if exists(cross_attn):
                assert exists(context)
@@ -392,6 +441,7 @@ class SemanticTransformer(nn.Module):
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        pad_id = -1,
@@ -410,14 +460,17 @@ class SemanticTransformer(nn.Module):
        self.eos_id = num_semantic_tokens
        self.pad_id = pad_id

        text_dim = get_encoded_dim(t5_name)
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            dim_context = get_encoded_dim(t5_name),
            cross_attend = has_condition,
            cross_attend = has_condition and not cond_as_self_attn_prefix,
            cond_as_self_attn_prefix = cond_as_self_attn_prefix,
            grad_shrink_alpha = grad_shrink_alpha,
            **kwargs
        )
@@ -466,6 +519,9 @@ class SemanticTransformer(nn.Module):
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        if exists(text_embeds):
            text_embeds = self.proj_text_embed(text_embeds)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if exists(text_mask) and cond_drop_prob > 0:
@@ -502,6 +558,7 @@ class CoarseTransformer(nn.Module):
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        **kwargs
@@ -523,14 +580,17 @@ class CoarseTransformer(nn.Module):
        codebook_size_with_eos = codebook_size + 1
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)

        text_dim = get_encoded_dim(t5_name)
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            dim_context = get_encoded_dim(t5_name),
            cross_attend = has_condition,
            cross_attend = has_condition and not cond_as_self_attn_prefix,
            cond_as_self_attn_prefix = cond_as_self_attn_prefix,
            grad_shrink_alpha = grad_shrink_alpha,
            **kwargs
        )
@@ -589,6 +649,8 @@ class CoarseTransformer(nn.Module):
        if exists(text_embeds):
            text_mask = torch.any(text_embeds != 0, dim = -1)

            text_embeds = self.proj_text_embed(text_embeds)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if exists(text_mask) and cond_drop_prob > 0:
@@ -663,6 +725,7 @@ class FineTransformer(nn.Module):
        ff_dropout = 0.,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_as_self_attn_prefix = False,
        cond_drop_prob = 0.5,
        grad_shrink_alpha = 0.1,
        **kwargs
@@ -684,14 +747,17 @@ class FineTransformer(nn.Module):

        self.eos_id = codebook_size

        text_dim = get_encoded_dim(t5_name)
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            dim_context = get_encoded_dim(t5_name),
            cross_attend = has_condition,
            cross_attend = has_condition and not cond_as_self_attn_prefix,
            cond_as_self_attn_prefix = cond_as_self_attn_prefix,
            grad_shrink_alpha = grad_shrink_alpha,
            **kwargs
        )
@@ -747,6 +813,9 @@ class FineTransformer(nn.Module):
                text_embeds = self.embed_text(text, output_device = device)
                text_mask = torch.any(text_embeds != 0, dim = -1)

        if exists(text_embeds):
            text_embeds = self.proj_text_embed(text_embeds)

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        if exists(text_mask) and cond_drop_prob > 0:
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.3.6',
  version = '0.4.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',