Loading zoo/ccip/attention_pool.py +1 −1 Original line number Diff line number Diff line Loading @@ -103,7 +103,7 @@ class AttentionPool2d_flat(nn.Module): self.num_heads = num_heads def forward(self, x): x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = torch.cat([x.mean(dim=0, keepdim=True), x[1:]], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x[:1], key=x, value=x, Loading zoo/ccip/caformer.py +2 −2 Original line number Diff line number Diff line Loading @@ -12,8 +12,8 @@ class CaformerBackbone(torch.nn.Module): self.caformer = CAFormerBuilder(**kwargs)() if pool_with_query: self.attnpool = nn.Sequential( AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, out_dims, n_query=8), AttentionPool2d_flat(8, out_dims, heads, out_dims), AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, self.caformer.output_dim, n_query=8), AttentionPool2d_flat(8, self.caformer.output_dim, heads, out_dims), ) else: self.attnpool = AttentionPool2d(self.input_resolution//32, self.caformer.output_dim, heads, out_dims) Loading Loading
zoo/ccip/attention_pool.py +1 −1 Original line number Diff line number Diff line Loading @@ -103,7 +103,7 @@ class AttentionPool2d_flat(nn.Module): self.num_heads = num_heads def forward(self, x): x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = torch.cat([x.mean(dim=0, keepdim=True), x[1:]], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x[:1], key=x, value=x, Loading
zoo/ccip/caformer.py +2 −2 Original line number Diff line number Diff line Loading @@ -12,8 +12,8 @@ class CaformerBackbone(torch.nn.Module): self.caformer = CAFormerBuilder(**kwargs)() if pool_with_query: self.attnpool = nn.Sequential( AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, out_dims, n_query=8), AttentionPool2d_flat(8, out_dims, heads, out_dims), AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, self.caformer.output_dim, n_query=8), AttentionPool2d_flat(8, self.caformer.output_dim, heads, out_dims), ) else: self.attnpool = AttentionPool2d(self.input_resolution//32, self.caformer.output_dim, heads, out_dims) Loading