Commit 7413682e authored by dzy7e
Browse files

attnpool_query

parent 52d34a91
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -103,7 +103,7 @@ class AttentionPool2d_flat(nn.Module):
        self.num_heads = num_heads

    def forward(self, x):
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x[1:]], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
+2 −2
Original line number Diff line number Diff line
@@ -12,8 +12,8 @@ class CaformerBackbone(torch.nn.Module):
        self.caformer = CAFormerBuilder(**kwargs)()
        if pool_with_query:
            self.attnpool = nn.Sequential(
                AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, out_dims, n_query=8),
                AttentionPool2d_flat(8, out_dims, heads, out_dims),
                AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, self.caformer.output_dim, n_query=8),
                AttentionPool2d_flat(8, self.caformer.output_dim, heads, out_dims),
            )
        else:
            self.attnpool = AttentionPool2d(self.input_resolution//32, self.caformer.output_dim, heads, out_dims)