Commit 3aa64352 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add atp and caformer for ccip

parent 03aabb9b
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
import torch
import torch.nn.functional as F
from torch import nn


class AttentionPool2d(nn.Module):
    """
    Pool a square 2D feature map into one vector per sample using multi-head attention.

    The map is flattened into a sequence of spatial tokens, the mean of those tokens is
    prepended, a learned positional embedding is added, and one attention step is run in
    which the mean token is the only query.

    Sizing examples: a backbone output of (1, 2048, 7, 7) matches ``spacial_dim=7`` and
    ``embed_dim=2048``; an output of (1, 512, 12, 12) matches ``spacial_dim=12`` and
    ``embed_dim=512``.

    :param spacial_dim: Side length (height == width) of the incoming feature map.
    :param embed_dim: Channel count of the incoming feature map.
    :param num_heads: Number of attention heads handed to the attention call.
    :param output_dim: Size of the pooled vector; ``embed_dim`` is used when this is falsy.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # One embedding row per spatial position, plus one for the prepended mean token.
        token_count = spacial_dim ** 2 + 1
        init_embedding = torch.randn(token_count, embed_dim) / embed_dim ** 0.5
        self.positional_embedding = nn.Parameter(init_embedding)

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        pooled_dim = output_dim or embed_dim
        self.c_proj = nn.Linear(embed_dim, pooled_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """
        Pool ``x`` of layout NCHW and return a tensor of layout N x output_dim.

        H * W must equal ``spacial_dim ** 2`` so that the positional embedding lines up
        with the token sequence.
        """
        tokens = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC

        # Insert a batch axis into the embedding so it broadcasts over N.
        position = self.positional_embedding[:, None, :].to(tokens.dtype)
        tokens = tokens + position  # (HW+1)NC

        # The three projections are kept as separate layers, so their biases have to be
        # packed into the single q/k/v bias vector the functional API expects.
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])

        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],  # only the mean token attends
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            use_separate_proj_weight=True,
            in_proj_weight=None,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            training=self.training,
            need_weights=False
        )
        # Drop the length-1 query axis: 1NC -> NC.
        return pooled.squeeze(0)
+5 −0
Original line number Diff line number Diff line
@@ -5,6 +5,8 @@ import clip
import torch
from torchvision.transforms import Compose

from .caformer import get_caformer


def get_clip_backbone(name="ViT-B/32") -> Tuple[torch.nn.Module, Compose]:
    model, preprocess = clip.load(name, device='cpu')
@@ -19,6 +21,9 @@ def register_backbone(name, func, *args, **kwargs):
    _KNOWN_BACKBONES[name] = partial(func, *args, **kwargs)


# Make the CAFormer factory available in the backbone registry under the name 'caformer'.
register_backbone('caformer', get_caformer)


def get_backbone(name: str) -> Tuple[torch.nn.Module, Compose]:
    if name.startswith(CLIP_PREFIX):
        clip_name = name[len(CLIP_PREFIX):]

zoo/ccip/caformer.py

0 → 100644
+38 −0
Original line number Diff line number Diff line
import torch.nn
from torchvision.transforms import InterpolationMode, Compose, Resize, CenterCrop, ToTensor, Normalize

from .attention_pool import AttentionPool2d
from ..monochrome.metaformer import CAFormerBuilder


class CaformerBackbone(torch.nn.Module):
    """
    CAFormer feature extractor followed by an attention-pooling head.

    :param input_resolution: Side length of the square input images. The pooling head is
        sized for a feature map of ``input_resolution // 32`` per side.
    :param heads: Number of attention heads used by the pooling head.
    :param out_dims: Output dimension requested from the pooling head.
    :param kwargs: Forwarded unchanged to ``CAFormerBuilder``.
    """

    def __init__(self, input_resolution: int = 384, heads: int = 32, out_dims: int = 1024, **kwargs):
        super().__init__()
        self.input_resolution = input_resolution

        # The builder is callable; invoking it yields the actual network.
        build_network = CAFormerBuilder(**kwargs)
        self.caformer = build_network()

        grid_side = self.input_resolution // 32
        feature_channels = self.caformer.output_dim
        self.attnpool = AttentionPool2d(grid_side, feature_channels, heads, out_dims)

    def _get_cnn_result(self, x):
        """Run every downsample/stage pair in order and return the result in BxCxHxW layout."""
        network = self.caformer
        for stage_index in range(network.num_stage):
            downsample = network.downsample_layers[stage_index]
            stage = network.stages[stage_index]
            x = stage(downsample(x))

        # BxHxWxC --> BxCxHxW
        return x.permute(0, 3, 1, 2)

    def forward(self, x):
        """Extract the feature map for ``x`` and reduce it with the attention-pooling head."""
        feature_map = self._get_cnn_result(x)
        return self.attnpool(feature_map)


def _convert_to_rgb(image):
    """Return ``image`` converted to RGB mode via its ``convert`` method."""
    return image.convert('RGB')


def get_caformer(input_resolution: int = 224, heads: int = 32, feat_dims: int = 1024, **kwargs):
    """
    Build a CAFormer backbone together with its matching image preprocessing pipeline.

    :param input_resolution: Side length images are resized and center-cropped to; also
        passed to the backbone as its input resolution.
    :param heads: Number of attention heads for the backbone's pooling head.
    :param feat_dims: Output feature dimension of the backbone.
    :param kwargs: Forwarded unchanged to ``CaformerBackbone``.
    :return: Tuple ``(backbone, transform)``.
    """
    transform = Compose([
        Resize(input_resolution, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(input_resolution),
        # A module-level function rather than a lambda, so the composed transform can be
        # pickled (needed e.g. when it is shipped to worker processes).
        _convert_to_rgb,
        ToTensor(),
        # NOTE(review): these look like the CLIP normalization statistics - confirm.
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

    return CaformerBackbone(input_resolution, heads, feat_dims, **kwargs), transform
+1 −0
Original line number Diff line number Diff line
@@ -627,6 +627,7 @@ class MetaFormer(nn.Module):
            self.head = head_fn(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.output_dim = dims[-1]

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):