Commit af903342 authored by Phil Wang's avatar Phil Wang
Browse files

create specially engineered relative positional bias for fine transformer, so...

create specially engineered relative positional bias for fine transformer, so coarse and fine sequences learn to attend to each other at relative distances apart
parent 7a35a8bc
Loading
Loading
Loading
Loading
+82 −8
Original line number Diff line number Diff line
@@ -33,7 +33,15 @@ def exists(val):
def default(val, d):
    return val if exists(val) else d

def always(val):
    def inner(*args, **kwargs):
        return val
    return inner

def maybe(fn):
    if not exists(fn):
        return always(None)

    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
@@ -382,6 +390,7 @@ class Transformer(nn.Module):
        ff_dropout = 0.,
        grad_shrink_alpha = 0.1,
        cond_as_self_attn_prefix = False,
        rel_pos_bias = True,
        **kwargs
    ):
        super().__init__()
@@ -394,7 +403,7 @@ class Transformer(nn.Module):

        self.layers = nn.ModuleList([])

        self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads)
        self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads) if rel_pos_bias else None

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
@@ -410,7 +419,8 @@ class Transformer(nn.Module):
        x,
        self_attn_mask = None,
        context = None,
        context_mask = None
        context_mask = None,
        attn_bias = None
    ):
        assert not (self.cond_as_self_attn_prefix and not exists(context))
        assert not (exists(context) and context.shape[-1] != self.dim_context), f'you had specified a conditioning dimension of {self.dim_context}, yet what was received by the transformer has dimension of {context.shape[-1]}'
@@ -419,7 +429,10 @@ class Transformer(nn.Module):

        x = self.grad_shrink(x) # from cogview paper, adopted by GLM 130B LLM, decreases likelihood of attention net instability

        rel_pos_bias = self.rel_pos_bias(n, device = device)
        if exists(attn_bias):
            rel_pos_bias = attn_bias
        else:
            rel_pos_bias = maybe(self.rel_pos_bias)(n, device = device)

        self_attn_kwargs = dict()
        if self.cond_as_self_attn_prefix:
@@ -798,10 +811,25 @@ class FineTransformer(nn.Module):
            ff_dropout = ff_dropout,
            cross_attend = has_condition and not cond_as_self_attn_prefix,
            cond_as_self_attn_prefix = cond_as_self_attn_prefix,
            rel_pos_bias = False,
            grad_shrink_alpha = grad_shrink_alpha,
            **kwargs
        )

        # doing a specialized attn bias so that corresponding time steps at fine and coarse sequences attend to each other better

        self.null_pos_bias = nn.Parameter(torch.randn(heads, 1, 1))

        pos_bias_mlp_dim = dim // 2
        self.pos_bias_mlp = nn.Sequential(
            Rearrange('... -> ... 1'),
            nn.Linear(1, pos_bias_mlp_dim),
            nn.SiLU(),
            nn.Linear(pos_bias_mlp_dim, pos_bias_mlp_dim),
            nn.SiLU(),
            nn.Linear(pos_bias_mlp_dim, heads)
        )

        self.codebook_size = codebook_size
        self.num_coarse_quantizers = num_coarse_quantizers
        self.num_fine_quantizers = num_fine_quantizers
@@ -866,14 +894,18 @@ class FineTransformer(nn.Module):

        b, n = coarse_token_ids.shape

        coarse_length = coarse_token_ids.shape[-1]
        coarse_offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
        coarse_offsets = repeat(coarse_offsets, 'q -> 1 (n q)', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
        coarse_offsets = coarse_offsets[:, :coarse_token_ids.shape[-1]]
        coarse_seq_length = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers)
        coarse_offsets = repeat(coarse_offsets, 'q -> 1 (n q)', n = coarse_seq_length)
        coarse_offsets = coarse_offsets[:, :coarse_length]
        coarse_token_ids = coarse_token_ids + coarse_offsets

        fine_length = fine_token_ids.shape[-1]
        fine_offsets = self.codebook_size * torch.arange(self.num_fine_quantizers, device = device)
        fine_offsets = repeat(fine_offsets, 'q -> 1 (n q)', n = ceil_div(fine_token_ids.shape[-1], self.num_fine_quantizers))
        fine_offsets = fine_offsets[:, :fine_token_ids.shape[-1]]
        fine_seq_length = ceil_div(fine_token_ids.shape[-1], self.num_fine_quantizers)
        fine_offsets = repeat(fine_offsets, 'q -> 1 (n q)', n = fine_seq_length)
        fine_offsets = fine_offsets[:, :fine_length]
        fine_token_ids = fine_token_ids + fine_offsets

        coarse_tokens = self.coarse_embedding(coarse_token_ids)
@@ -897,7 +929,49 @@ class FineTransformer(nn.Module):
            fine_tokens
        ), dim = 1)

        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)
        # an engineered attention bias so coarse and fine sequences attend to each other better

        max_seq_len = max(coarse_seq_length, fine_seq_length)

        coarse_pos = torch.arange(coarse_seq_length, device = device)
        fine_pos = torch.arange(fine_seq_length, device = device)

        coarse_pos = repeat(coarse_pos, 'n -> (n q)', q = self.num_coarse_quantizers)[:coarse_length]
        fine_pos = repeat(fine_pos, 'n -> (n q)', q = self.num_fine_quantizers)[:fine_length]

        coarse_pos = F.pad(coarse_pos, (1, 0), value = -1) # -1 for start token
        fine_pos = F.pad(fine_pos, (1, 0), value = -1)

        seq_positions = torch.cat((coarse_pos, fine_pos), dim = -1)

        rel_dist = (rearrange(seq_positions, 'i -> i 1') - rearrange(seq_positions, 'j -> 1 j'))
        rel_dist = rel_dist + max_seq_len # offset so all positive indices

        mlp_inp = torch.arange(-max_seq_len, max_seq_len + 1, device = device).float()
        attn_bias = self.pos_bias_mlp(mlp_inp)

        attn_bias = rearrange(attn_bias[rel_dist], '... h -> h ...')

        # need to make sure start token has a custom positional bias

        is_start_token_seq = seq_positions == -1
        start_token_mask = rearrange(is_start_token_seq, 'i -> i 1') | rearrange(is_start_token_seq, 'j -> 1 j')

        attn_bias = torch.where(
            start_token_mask,
            self.null_pos_bias,
            attn_bias,
        )

        # attention

        tokens = self.transformer(
            tokens,
            context = text_embeds,
            self_attn_mask = self_attn_mask,
            context_mask = text_mask,
            attn_bias = attn_bias
        )

        pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, (n + 1):]

+1 −1
Original line number Diff line number Diff line
__version__ = '0.21.4'
__version__ = '0.22.0'