Commit 52d34a91 authored by dzy7e's avatar dzy7e
Browse files

attnpool_query

parent e81035d7
Loading
Loading
Loading
Loading
+84 −1
Original line number Diff line number Diff line
import torch
import torch.nn.functional as F
from torch import nn

from einops import repeat

class AttentionPool2d(nn.Module):
    """
@@ -42,3 +42,86 @@ class AttentionPool2d(nn.Module):
            need_weights=False
        )
        return x.squeeze(0)

class AttentionPool2d_query(nn.Module):
    """
    If the CNN's output is (1, 2048, 7, 7), then the parameters should be (7, 2048, 32, 1024),
    so if the CNN's output is (1, 512, 12, 12), then it should be (12, 512, 32?, 1024).
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, n_query=8):
        super().__init__()
        self.query_emb = nn.Parameter(torch.randn(n_query, 1, embed_dim) / embed_dim ** 0.5)

        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        q = torch.cat((x[:1], repeat(self.query_emb, 'hw 1 c -> hw n c', n=x.shape[1])), dim=0)

        x, _ = F.multi_head_attention_forward(
            query=q, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        ) # [N_q+1, N, C]
        return x

class AttentionPool2d_flat(nn.Module):
    """
    If the CNN's output is (1, 2048, 7, 7), then the parameters should be (7, 2048, 32, 1024),
    so if the CNN's output is (1, 512, 12, 12), then it should be (12, 512, 32?, 1024).
    """

    def __init__(self, n_token: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(n_token + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)
 No newline at end of file
+2 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import clip
import torch
from torchvision.transforms import Compose

from .caformer import get_caformer, get_caformer_s18
from .caformer import get_caformer, get_caformer_s18, get_caformer_query


def get_clip_backbone(name="ViT-B/32") -> Tuple[torch.nn.Module, Compose]:
@@ -22,6 +22,7 @@ def register_backbone(name, func, *args, **kwargs):


register_backbone('caformer', get_caformer)
register_backbone('caformer_query', get_caformer_query)
register_backbone('caformer_s18', get_caformer_s18)


+17 −4
Original line number Diff line number Diff line
import torch.nn
from torchvision.transforms import Normalize

from .attention_pool import AttentionPool2d
from .attention_pool import AttentionPool2d, AttentionPool2d_query, AttentionPool2d_flat
from ..monochrome.metaformer import CAFormerBuilder

from torch import nn

class CaformerBackbone(torch.nn.Module):
    def __init__(self, input_resolution: int = 384, heads: int = 8, out_dims: int = 768, **kwargs):
    def __init__(self, input_resolution: int = 384, heads: int = 8, out_dims: int = 768, pool_with_query=False, **kwargs):
        torch.nn.Module.__init__(self)
        self.input_resolution = input_resolution
        self.caformer = CAFormerBuilder(**kwargs)()
        if pool_with_query:
            self.attnpool = nn.Sequential(
                AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, out_dims, n_query=8),
                AttentionPool2d_flat(8, out_dims, heads, out_dims),
            )
        else:
            self.attnpool = AttentionPool2d(self.input_resolution//32, self.caformer.output_dim, heads, out_dims)

    def _get_cnn_result(self, x):
@@ -33,6 +39,13 @@ def get_caformer(input_resolution: int = 384, heads: int = 8, feat_dims: int = 7

    return CaformerBackbone(input_resolution, heads, feat_dims, **kwargs), transform

def get_caformer_query(input_resolution: int = 384, heads: int = 8, feat_dims: int = 768, **kwargs):
    transform = [
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ]

    return CaformerBackbone(input_resolution, heads, feat_dims, pool_with_query=True, **kwargs), transform

def get_caformer_s18(input_resolution: int = 384, heads: int = 8, feat_dims: int = 768, **kwargs):
    transform = [
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+44 −30
Original line number Diff line number Diff line
@@ -29,10 +29,28 @@ class WeakRandAugment(transforms.RandAugment):
            "Equalize":(torch.tensor(0.0), False),
        }

class WeakRandAugment2(transforms.RandAugment):
    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[torch.Tensor, bool]]:
        return {
            # op_name: (magnitudes, signed)
            "Identity":(torch.tensor(0.0), False),
            "ShearX":(torch.linspace(0.0, 0.08, num_bins), True),
            "ShearY":(torch.linspace(0.0, 0.08, num_bins), True),
            "TranslateX":(torch.linspace(0.0, 0.08*image_size[1], num_bins), True),
            "TranslateY":(torch.linspace(0.0, 0.08*image_size[0], num_bins), True),
            "Rotate":(torch.linspace(0.0, 8.0, num_bins), True),
            "Brightness":(torch.linspace(0.0, 0.05, num_bins), True),
            "Contrast":(torch.linspace(0.0, 0.05, num_bins), True),
            "Sharpness":(torch.linspace(0.0, 0.2, num_bins), True),
            "Posterize":(8-(torch.arange(num_bins)/((num_bins-1)/4)).round().int(), False),
            "AutoContrast":(torch.tensor(0.0), False),
            "Equalize":(torch.tensor(0.0), False),
        }

TRAIN_TRANSFORM = [
    transforms.Resize((416, 416)),
    transforms.RandomHorizontalFlip(),
    WeakRandAugment(),
    WeakRandAugment2(),
    transforms.RandomCrop(384),
    transforms.ToTensor(),
]
@@ -42,7 +60,6 @@ TEST_TRANSFORM = [
    transforms.ToTensor(),
]


class ImagesDataset(Dataset):
    def __init__(self, items: List[Tuple[str, int]], transform=None):
        self.items: List[Tuple[str, int]] = items
@@ -73,7 +90,6 @@ class ImagesDataset(Dataset):
        return ImagesDataset(train_items, train_transform or self.transform), \
            ImagesDataset(test_items, test_transform or self.transform)


class CCIPImagesDataset(ImagesDataset):
    def __init__(self, root_dir, transform=None):
        _maxid = 0
@@ -88,7 +104,6 @@ class CCIPImagesDataset(ImagesDataset):

        super(CCIPImagesDataset, self).__init__(_items, transform)


class CharacterDataset(Dataset):
    def __init__(self, images_dataset: ImagesDataset, group_size: int = 100,
                 prob: float = 0.5, force_prob: bool = True):
@@ -140,7 +155,6 @@ class CharacterDataset(Dataset):
        return torch.stack(list(map(torch.as_tensor, images))), \
            torch.stack(list(map(torch.as_tensor, labels)))


class FastCharacterDataset(Dataset):
    def __init__(self, images_dataset: ImagesDataset, group_size: int = 100,
                 prob: float = 0.5, **kwargs):
@@ -162,7 +176,8 @@ class FastCharacterDataset(Dataset):
    def reset(self):
        idxs = np.arange(0, len(self.images_dataset.items))
        np.random.shuffle(idxs)
        self.idxs = idxs[:-(len(idxs) % self.group_size)]
        rest = len(idxs)%self.group_size
        self.idxs = idxs[:-rest] if rest>0 else idxs

    def __len__(self):
        return len(self.idxs)
@@ -182,7 +197,6 @@ class FastCharacterDataset(Dataset):

        return image, cid


def char_collect_fn(batch):
    img_list, cid_list = [], []
    for data in batch:
+1 −1
Original line number Diff line number Diff line
@@ -145,7 +145,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        steps_per_epoch=len(train_dataloader)//accelerator.num_processes, epochs=max_epochs,
        pct_start=0.15, final_div_factor=20.
    )
    # model = torch.compile(model)
    model = torch.compile(model)

    model, optimizer, train_dataloader, test_dataloader, scheduler = \
        accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)