Commit 25b9a319 authored by dzy7e's avatar dzy7e
Browse files

update

parent a77587ce
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ class CaformerBackbone(torch.nn.Module):
        return x


def get_caformer(input_resolution: int = 384, heads: int = 32, feat_dims: int = 1024, **kwargs):
def get_caformer(input_resolution: int = 384, heads: int = 8, feat_dims: int = 768, **kwargs):
    transform = [
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ]
+1 −0
Original line number Diff line number Diff line
@@ -73,5 +73,6 @@ if __name__ == '__main__':

    data = torch.randn(4,3,384,384).cuda()
    model = CCIP('caformer').cuda()
    print(model.feature.backbone.attnpool.num_heads)
    print(model(data))