Commit 0491eaaf authored by Phil Wang's avatar Phil Wang
Browse files

in coarse transformer, make sure that coarse tokens attending to semantic...

In the coarse transformer, make sure that coarse tokens attending to semantic tokens (cross attention) do not use relative positions
parent dd4784ab
Loading
Loading
Loading
Loading
+34 −4
Original line number Diff line number Diff line
@@ -198,7 +198,12 @@ class RelativePositionBias(nn.Module):

        self.net.append(nn.Linear(dim, heads))

    def forward(self, n, device = torch.device('cpu')):
    @property
    def device(self):
        """Return the device of the first parameter yielded by ``self.parameters()``."""
        first_param = next(self.parameters())
        return first_param.device

    def forward(self, n):
        device = self.device
        pos = torch.arange(n, device = device)
        rel_pos = (rearrange(pos, 'i -> i 1') - rearrange(pos, 'j -> 1 j'))
        rel_pos += (n - 1)
@@ -432,7 +437,7 @@ class Transformer(nn.Module):
        if exists(attn_bias):
            rel_pos_bias = attn_bias
        else:
            rel_pos_bias = maybe(self.rel_pos_bias)(n, device = device)
            rel_pos_bias = maybe(self.rel_pos_bias)(n)

        self_attn_kwargs = dict()
        if self.cond_as_self_attn_prefix:
@@ -623,6 +628,8 @@ class CoarseTransformer(nn.Module):
        text_dim = default(cond_dim, get_encoded_dim(t5_name))
        self.proj_text_embed = nn.Linear(text_dim, dim, bias = False) if text_dim != dim else nn.Identity()

        self.cross_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
@@ -677,6 +684,7 @@ class CoarseTransformer(nn.Module):
        return_only_coarse_logits = False
    ):
        b, device = semantic_token_ids.shape[0], semantic_token_ids.device
        arange = partial(torch.arange, device = device)

        has_text = exists(text) or exists(text_embeds)
        assert not (self.has_condition ^ has_text)
@@ -699,7 +707,7 @@ class CoarseTransformer(nn.Module):

        coarse_token_ids, semantic_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, semantic_token_ids))

        offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
        offsets = self.codebook_size * arange(self.num_coarse_quantizers)
        offsets = repeat(offsets, 'q -> 1 (n q)', n = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers))
        offsets = offsets[:, :coarse_token_ids.shape[-1]]
        coarse_token_ids = coarse_token_ids + offsets
@@ -723,7 +731,29 @@ class CoarseTransformer(nn.Module):
            coarse_tokens
        ), dim = 1)

        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)
        # engineer the attention bias so that cross attention is not dominated by relative positions

        seq_len = tokens.shape[-2]
        attn_bias = self.transformer.rel_pos_bias(seq_len)

        is_semantic = arange(seq_len) < (semantic_seq_len + 1) # semantic seq len + start token
        is_cross_attn = rearrange(is_semantic, 'i -> i 1') ^ rearrange(is_semantic, 'j -> 1 j')

        attn_bias = torch.where(
            is_cross_attn,
            self.cross_attn_bias,
            attn_bias
        )

        # attend

        tokens = self.transformer(
            tokens,
            context = text_embeds,
            attn_bias = attn_bias,
            self_attn_mask = self_attn_mask,
            context_mask = text_mask
        )

        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):]

+1 −1
Original line number Diff line number Diff line
__version__ = '0.23.7'
__version__ = '0.24.0'