Commit a77587ce authored by dzy7e's avatar dzy7e
Browse files

update

parent dd02ffdf
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from ..monochrome.metaformer import CAFormerBuilder


class CaformerBackbone(torch.nn.Module):
    def __init__(self, input_resolution: int = 384, heads: int = 32, out_dims: int = 1024, **kwargs):
    def __init__(self, input_resolution: int = 384, heads: int = 8, out_dims: int = 768, **kwargs):
        torch.nn.Module.__init__(self)
        self.input_resolution = input_resolution
        self.caformer = CAFormerBuilder(**kwargs)()
+5 −5
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from .backbone import get_backbone

class CCIPBatchMetrics(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        super().__init__()
        self.logit_scale = nn.Parameter(torch.ones([])*np.log(1/0.07))
        #self.sim = nn.CosineSimilarity(dim=-1)

@@ -27,9 +27,9 @@ class CCIPBatchMetrics(nn.Module):
        return logits_per_image


class CCIPFeature(torch.nn.Module):
class CCIPFeature(nn.Module):
    def __init__(self, name: str = "clip/ViT-B/32"):
        torch.nn.Module.__init__(self)
        super().__init__()
        self.backbone, self.preprocess = get_backbone(name)

    def forward(self, x):
@@ -38,9 +38,9 @@ class CCIPFeature(torch.nn.Module):
        return x


class CCIP(torch.nn.Module):
class CCIP(nn.Module):
    def __init__(self, name: str = "clip/ViT-B/32"):
        torch.nn.Module.__init__(self)
        super().__init__()
        self.feature = CCIPFeature(name)
        self.metrics = CCIPBatchMetrics()