Commit b29141f8 authored by Phil Wang's avatar Phil Wang
Browse files

use 2d dynamic positional bias for fine transformer, to try to improve...

use 2d dynamic positional bias for fine transformer, to try to improve training at greater number of fine quantizers
parent 9ba1b040
Loading
Loading
Loading
Loading
+19 −15
Original line number Diff line number Diff line
@@ -822,8 +822,7 @@ class FineTransformer(nn.Module):

        pos_bias_mlp_dim = dim // 2
        self.pos_bias_mlp = nn.Sequential(
            Rearrange('... -> ... 1'),
            nn.Linear(1, pos_bias_mlp_dim),
            nn.Linear(2, pos_bias_mlp_dim),
            nn.SiLU(),
            nn.Linear(pos_bias_mlp_dim, pos_bias_mlp_dim),
            nn.SiLU(),
@@ -895,18 +894,18 @@ class FineTransformer(nn.Module):
        b, n = coarse_token_ids.shape

        coarse_length = coarse_token_ids.shape[-1]
        coarse_offsets = self.codebook_size * torch.arange(self.num_coarse_quantizers, device = device)
        coarse_offsets = torch.arange(self.num_coarse_quantizers, device = device)
        coarse_seq_length = ceil_div(coarse_token_ids.shape[-1], self.num_coarse_quantizers)
        coarse_offsets = repeat(coarse_offsets, 'q -> 1 (n q)', n = coarse_seq_length)
        coarse_offsets = coarse_offsets[:, :coarse_length]
        coarse_token_ids = coarse_token_ids + coarse_offsets
        coarse_offsets = repeat(coarse_offsets, 'q -> (n q)', n = coarse_seq_length)
        coarse_offsets = coarse_offsets[:coarse_length]
        coarse_token_ids = coarse_token_ids + rearrange(coarse_offsets, '... -> 1 ...') * self.codebook_size

        fine_length = fine_token_ids.shape[-1]
        fine_offsets = self.codebook_size * torch.arange(self.num_fine_quantizers, device = device)
        fine_offsets = torch.arange(self.num_fine_quantizers, device = device)
        fine_seq_length = ceil_div(fine_token_ids.shape[-1], self.num_fine_quantizers)
        fine_offsets = repeat(fine_offsets, 'q -> 1 (n q)', n = fine_seq_length)
        fine_offsets = fine_offsets[:, :fine_length]
        fine_token_ids = fine_token_ids + fine_offsets
        fine_offsets = repeat(fine_offsets, 'q -> (n q)', n = fine_seq_length)
        fine_offsets = fine_offsets[:fine_length]
        fine_token_ids = fine_token_ids + rearrange(fine_offsets, '... -> 1 ...') * self.codebook_size

        coarse_tokens = self.coarse_embedding(coarse_token_ids)
        fine_tokens = self.fine_embedding(fine_token_ids)
@@ -944,13 +943,18 @@ class FineTransformer(nn.Module):

        seq_positions = torch.cat((coarse_pos, fine_pos), dim = -1)

        rel_dist = (rearrange(seq_positions, 'i -> i 1') - rearrange(seq_positions, 'j -> 1 j'))
        rel_dist = rel_dist + max_seq_len # offset so all positive indices
        coarse_offsets = F.pad(coarse_offsets, (1, 0), value = 0)
        fine_offsets = fine_offsets + self.num_coarse_quantizers
        fine_offsets = F.pad(fine_offsets, (1, 0), value = 0)

        mlp_inp = torch.arange(-max_seq_len, max_seq_len + 1, device = device).float()
        attn_bias = self.pos_bias_mlp(mlp_inp)
        seq_offsets = torch.cat((coarse_offsets, fine_offsets), dim = -1)

        attn_bias = rearrange(attn_bias[rel_dist], '... h -> h ...')
        pos_mlp_input = torch.stack((seq_positions, seq_offsets), dim = -1)
        rel_dist = (rearrange(pos_mlp_input, 'i c -> i 1 c') - rearrange(pos_mlp_input, 'j c -> 1 j c'))

        attn_bias = self.pos_bias_mlp(rel_dist.float())

        attn_bias = rearrange(attn_bias, '... h -> h ...')

        # need to make sure start token has a custom positional bias

+1 −1
Original line number Diff line number Diff line
__version__ = '0.22.3'
__version__ = '0.23.2'