Commit 3c639577 authored by Phil Wang
Browse files

fix the memory issue with the 2d relative attention bias (seq and quantizer positions) in the fine transformer

fix the memory issue with the 2d relative attention bias (seq and quantizer positions) in the fine transformer
parent ff608683
Loading
Loading
Loading
Loading
+36 −3
Original line number Diff line number Diff line
@@ -938,7 +938,7 @@ class FineTransformer(nn.Module):
        coarse_pos = repeat(coarse_pos, 'n -> (n q)', q = self.num_coarse_quantizers)[:coarse_length]
        fine_pos = repeat(fine_pos, 'n -> (n q)', q = self.num_fine_quantizers)[:fine_length]

        coarse_pos = F.pad(coarse_pos, (1, 0), value = -1) # -1 for start token
        coarse_pos = F.pad(coarse_pos, (1, 0), value = -1)
        fine_pos = F.pad(fine_pos, (1, 0), value = -1)

        seq_positions = torch.cat((coarse_pos, fine_pos), dim = -1)
@@ -949,10 +949,43 @@ class FineTransformer(nn.Module):

        seq_offsets = torch.cat((coarse_offsets, fine_offsets), dim = -1)

        pos_mlp_input = torch.stack((seq_positions, seq_offsets), dim = -1)
        pos_mlp_input = torch.stack((seq_positions.clamp(min = 0), seq_offsets), dim = -1)

        num_offsets = self.num_fine_quantizers + self.num_coarse_quantizers

        # relative positions are always (2 * N - 1), where N is the length of the dimension

        rel_seq_len, rel_offsets = map(lambda n: 2 * n - 1, (max_seq_len, num_offsets))

        # get all relative distances

        rel_dist = (rearrange(pos_mlp_input, 'i c -> i 1 c') - rearrange(pos_mlp_input, 'j c -> 1 j c'))

        attn_bias = self.pos_bias_mlp(rel_dist.float())
        # get all possible relative distances for the attention bias to be computed from the mlp
        # which would be - (2 * N - 1) * (2 * Q - 1) - where N = sequence length and Q = total quantizers

        rel_seq_len_range = repeat(torch.arange(rel_seq_len, device = device), 'n -> (n q)', q = rel_offsets)
        rel_offset_range = repeat(torch.arange(rel_offsets, device = device), 'q -> (n q)', n = rel_seq_len)

        mlp_inputs = torch.stack((rel_seq_len_range, rel_offset_range), dim = -1)

        # implicitly parameterized relative distances, by sequence and quantizer positions

        attn_bias = self.pos_bias_mlp(mlp_inputs.float())

        # translate coordinates of (rel_seq_pos, rel_quantizer_offset) -> positive index to select from attn bias

        rel_dist_seq_pos, rel_dist_seq_offset = rel_dist.unbind(dim = -1)

        rel_dist_seq_pos += max_seq_len - 1
        rel_dist_seq_offset += num_offsets - 1

        rel_dist_indices = rel_dist_seq_pos * rel_offsets + rel_dist_seq_offset

        # select the relative positional attention bias outputted by the MLP
        # savings go from (N * Q) ^ 2 -> ~ (4 * N * Q)

        attn_bias = attn_bias[rel_dist_indices]

        attn_bias = rearrange(attn_bias, '... h -> h ...')

+1 −1
Original line number Diff line number Diff line
__version__ = '0.23.3'
__version__ = '0.23.5'