Loading zoo/ccip/attention_pool.py +84 −1 Original line number Diff line number Diff line import torch import torch.nn.functional as F from torch import nn from einops import repeat class AttentionPool2d(nn.Module): """ Loading Loading @@ -42,3 +42,86 @@ class AttentionPool2d(nn.Module): need_weights=False ) return x.squeeze(0) class AttentionPool2d_query(nn.Module): """ If the CNN's output is (1, 2048, 7, 7), then the parameters should be (7, 2048, 32, 1024), so if the CNN's output is (1, 512, 12, 12), then it should be (12, 512, 32?, 1024). """ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, n_query=8): super().__init__() self.query_emb = nn.Parameter(torch.randn(n_query, 1, embed_dim) / embed_dim ** 0.5) self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC q = torch.cat((x[:1], repeat(self.query_emb, 'hw 1 c -> hw n c', n=x.shape[1])), dim=0) x, _ = F.multi_head_attention_forward( query=q, key=x, value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0, out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False ) # [N_q+1, N, C] return x class AttentionPool2d_flat(nn.Module): """ If 
the CNN's output is (1, 2048, 7, 7), then the parameters should be (7, 2048, 32, 1024), so if the CNN's output is (1, 512, 12, 12), then it should be (12, 512, 32?, 1024). """ def __init__(self, n_token: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() self.positional_embedding = nn.Parameter(torch.randn(n_token + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x[:1], key=x, value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0, out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False ) return x.squeeze(0) No newline at end of file zoo/ccip/backbone.py +2 −1 Original line number Diff line number Diff line Loading @@ -5,7 +5,7 @@ import clip import torch from torchvision.transforms import Compose from .caformer import get_caformer, get_caformer_s18 from .caformer import get_caformer, get_caformer_s18, get_caformer_query def get_clip_backbone(name="ViT-B/32") -> Tuple[torch.nn.Module, Compose]: Loading @@ -22,6 +22,7 @@ def register_backbone(name, func, *args, **kwargs): register_backbone('caformer', get_caformer) register_backbone('caformer_query', get_caformer_query) register_backbone('caformer_s18', get_caformer_s18) Loading zoo/ccip/caformer.py +17 −4 Original line number 
Diff line number Diff line import torch.nn from torchvision.transforms import Normalize from .attention_pool import AttentionPool2d from .attention_pool import AttentionPool2d, AttentionPool2d_query, AttentionPool2d_flat from ..monochrome.metaformer import CAFormerBuilder from torch import nn class CaformerBackbone(torch.nn.Module): def __init__(self, input_resolution: int = 384, heads: int = 8, out_dims: int = 768, **kwargs): def __init__(self, input_resolution: int = 384, heads: int = 8, out_dims: int = 768, pool_with_query=False, **kwargs): torch.nn.Module.__init__(self) self.input_resolution = input_resolution self.caformer = CAFormerBuilder(**kwargs)() if pool_with_query: self.attnpool = nn.Sequential( AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, out_dims, n_query=8), AttentionPool2d_flat(8, out_dims, heads, out_dims), ) else: self.attnpool = AttentionPool2d(self.input_resolution//32, self.caformer.output_dim, heads, out_dims) def _get_cnn_result(self, x): Loading @@ -33,6 +39,13 @@ def get_caformer(input_resolution: int = 384, heads: int = 8, feat_dims: int = 7 return CaformerBackbone(input_resolution, heads, feat_dims, **kwargs), transform def get_caformer_query(input_resolution: int = 384, heads: int = 8, feat_dims: int = 768, **kwargs): transform = [ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ] return CaformerBackbone(input_resolution, heads, feat_dims, pool_with_query=True, **kwargs), transform def get_caformer_s18(input_resolution: int = 384, heads: int = 8, feat_dims: int = 768, **kwargs): transform = [ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), Loading zoo/ccip/dataset.py +44 −30 Original line number Diff line number Diff line Loading @@ -29,10 +29,28 @@ class WeakRandAugment(transforms.RandAugment): "Equalize":(torch.tensor(0.0), False), } class WeakRandAugment2(transforms.RandAugment): def _augmentation_space(self, num_bins: 
int, image_size: Tuple[int, int]) -> Dict[str, Tuple[torch.Tensor, bool]]: return { # op_name: (magnitudes, signed) "Identity":(torch.tensor(0.0), False), "ShearX":(torch.linspace(0.0, 0.08, num_bins), True), "ShearY":(torch.linspace(0.0, 0.08, num_bins), True), "TranslateX":(torch.linspace(0.0, 0.08*image_size[1], num_bins), True), "TranslateY":(torch.linspace(0.0, 0.08*image_size[0], num_bins), True), "Rotate":(torch.linspace(0.0, 8.0, num_bins), True), "Brightness":(torch.linspace(0.0, 0.05, num_bins), True), "Contrast":(torch.linspace(0.0, 0.05, num_bins), True), "Sharpness":(torch.linspace(0.0, 0.2, num_bins), True), "Posterize":(8-(torch.arange(num_bins)/((num_bins-1)/4)).round().int(), False), "AutoContrast":(torch.tensor(0.0), False), "Equalize":(torch.tensor(0.0), False), } TRAIN_TRANSFORM = [ transforms.Resize((416, 416)), transforms.RandomHorizontalFlip(), WeakRandAugment(), WeakRandAugment2(), transforms.RandomCrop(384), transforms.ToTensor(), ] Loading @@ -42,7 +60,6 @@ TEST_TRANSFORM = [ transforms.ToTensor(), ] class ImagesDataset(Dataset): def __init__(self, items: List[Tuple[str, int]], transform=None): self.items: List[Tuple[str, int]] = items Loading Loading @@ -73,7 +90,6 @@ class ImagesDataset(Dataset): return ImagesDataset(train_items, train_transform or self.transform), \ ImagesDataset(test_items, test_transform or self.transform) class CCIPImagesDataset(ImagesDataset): def __init__(self, root_dir, transform=None): _maxid = 0 Loading @@ -88,7 +104,6 @@ class CCIPImagesDataset(ImagesDataset): super(CCIPImagesDataset, self).__init__(_items, transform) class CharacterDataset(Dataset): def __init__(self, images_dataset: ImagesDataset, group_size: int = 100, prob: float = 0.5, force_prob: bool = True): Loading Loading @@ -140,7 +155,6 @@ class CharacterDataset(Dataset): return torch.stack(list(map(torch.as_tensor, images))), \ torch.stack(list(map(torch.as_tensor, labels))) class FastCharacterDataset(Dataset): def __init__(self, images_dataset: 
ImagesDataset, group_size: int = 100, prob: float = 0.5, **kwargs): Loading @@ -162,7 +176,8 @@ class FastCharacterDataset(Dataset): def reset(self): idxs = np.arange(0, len(self.images_dataset.items)) np.random.shuffle(idxs) self.idxs = idxs[:-(len(idxs) % self.group_size)] rest = len(idxs)%self.group_size self.idxs = idxs[:-rest] if rest>0 else idxs def __len__(self): return len(self.idxs) Loading @@ -182,7 +197,6 @@ class FastCharacterDataset(Dataset): return image, cid def char_collect_fn(batch): img_list, cid_list = [], [] for data in batch: Loading zoo/ccip/train_.py +1 −1 Original line number Diff line number Diff line Loading @@ -145,7 +145,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio steps_per_epoch=len(train_dataloader)//accelerator.num_processes, epochs=max_epochs, pct_start=0.15, final_div_factor=20. ) # model = torch.compile(model) model = torch.compile(model) model, optimizer, train_dataloader, test_dataloader, scheduler = \ accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) Loading Loading
# zoo/ccip/attention_pool.py
#
# NOTE: the pre-existing ``AttentionPool2d`` class of this file is collapsed in
# the diff view this text was taken from; it is not modified and not reproduced
# here. Only the two pooling variants added after it are shown below.
# The learned queries are broadcast with ``Tensor.expand``, so these two
# classes no longer need ``einops``.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


def _attend(pool: nn.Module, query: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Run one multi-head attention pass using the projection layers of ``pool``.

    ``pool`` must expose ``q_proj``, ``k_proj``, ``v_proj``, ``c_proj``
    (``nn.Linear``) and ``num_heads``. ``tokens`` is used as both key and
    value. All tensors are sequence-first: ``(L, N, C)``.

    :param pool: Module that owns the projection layers.
    :param query: Query tokens, shape ``(L_q, N, C)``.
    :param tokens: Key/value tokens, shape ``(L_kv, N, C)``.
    :return: Attended tokens, shape ``(L_q, N, C_out)`` where ``C_out`` is the
        output width of ``pool.c_proj``.
    """
    attended, _ = F.multi_head_attention_forward(
        query=query, key=tokens, value=tokens,
        embed_dim_to_check=tokens.shape[-1],
        num_heads=pool.num_heads,
        q_proj_weight=pool.q_proj.weight,
        k_proj_weight=pool.k_proj.weight,
        v_proj_weight=pool.v_proj.weight,
        in_proj_weight=None,
        # The functional API expects the three input biases packed as q|k|v.
        in_proj_bias=torch.cat([pool.q_proj.bias, pool.k_proj.bias, pool.v_proj.bias]),
        bias_k=None,
        bias_v=None,
        add_zero_attn=False,
        dropout_p=0,
        out_proj_weight=pool.c_proj.weight,
        out_proj_bias=pool.c_proj.bias,
        use_separate_proj_weight=True,
        training=pool.training,
        need_weights=False,
    )
    return attended


class AttentionPool2d_query(nn.Module):
    """
    Attention pooling over a square feature map with extra learned queries.

    The input is a feature map of shape ``(N, embed_dim, H, W)`` with
    ``H * W == spacial_dim ** 2``. A mean token is prepended to the flattened
    map and a positional embedding is added. The resulting tokens are then
    attended to by ``n_query + 1`` queries: the mean token itself, followed by
    ``n_query`` learned query embeddings.

    The output is **not** reduced to a single vector. It has shape
    ``(n_query + 1, N, output_dim)``, where index 0 is the mean-token query.

    :param spacial_dim: Side length of the (square) feature map.
    :param embed_dim: Channel count of the feature map.
    :param num_heads: Number of attention heads.
    :param output_dim: Output width; defaults to ``embed_dim`` when falsy.
    :param n_query: Number of learned query embeddings.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int,
                 output_dim: Optional[int] = None, n_query: int = 8):
        super().__init__()
        # Parameter creation order is kept so seeded initialisation is unchanged.
        self.query_emb = nn.Parameter(torch.randn(n_query, 1, embed_dim) / embed_dim ** 0.5)
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC

        # Broadcast the learned queries over the batch axis. The dtype cast
        # mirrors the one applied to the positional embedding above, so query
        # and key always share a dtype.
        learned = self.query_emb.to(x.dtype).expand(-1, x.shape[1], -1)  # (N_q)NC
        q = torch.cat((x[:1], learned), dim=0)  # (N_q+1)NC

        return _attend(self, q, x)  # [N_q+1, N, C]


class AttentionPool2d_flat(nn.Module):
    """
    Attention pooling over an already flattened token sequence.

    The input has shape ``(n_token, N, embed_dim)`` (sequence first). A mean
    token is prepended and a positional embedding of length ``n_token + 1`` is
    added. The mean token is then used as the single query over all tokens.
    The output has shape ``(N, output_dim)``.

    The input must contain exactly ``n_token`` tokens, otherwise the
    positional-embedding addition fails to broadcast. In particular, when this
    module is stacked after ``AttentionPool2d_query``, ``n_token`` has to be
    ``n_query + 1``, because that module emits ``n_query + 1`` tokens.
    NOTE(review): ``CaformerBackbone`` builds ``AttentionPool2d_flat(8, ...)``
    after a query pool with ``n_query=8`` -- that looks one token short;
    confirm against its ``forward``.

    :param n_token: Number of input tokens (without the prepended mean token).
    :param embed_dim: Width of each token.
    :param num_heads: Number of attention heads.
    :param output_dim: Output width; defaults to ``embed_dim`` when falsy.
    """

    def __init__(self, n_token: int, embed_dim: int, num_heads: int, output_dim: Optional[int] = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(n_token + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        # Single query (the mean token) -> output is (1, N, C); drop the length-1 axis.
        return _attend(self, x[:1], x).squeeze(0)
zoo/ccip/backbone.py +2 −1 Original line number Diff line number Diff line Loading @@ -5,7 +5,7 @@ import clip import torch from torchvision.transforms import Compose from .caformer import get_caformer, get_caformer_s18 from .caformer import get_caformer, get_caformer_s18, get_caformer_query def get_clip_backbone(name="ViT-B/32") -> Tuple[torch.nn.Module, Compose]: Loading @@ -22,6 +22,7 @@ def register_backbone(name, func, *args, **kwargs): register_backbone('caformer', get_caformer) register_backbone('caformer_query', get_caformer_query) register_backbone('caformer_s18', get_caformer_s18) Loading
zoo/ccip/caformer.py +17 −4 Original line number Diff line number Diff line import torch.nn from torchvision.transforms import Normalize from .attention_pool import AttentionPool2d from .attention_pool import AttentionPool2d, AttentionPool2d_query, AttentionPool2d_flat from ..monochrome.metaformer import CAFormerBuilder from torch import nn class CaformerBackbone(torch.nn.Module): def __init__(self, input_resolution: int = 384, heads: int = 8, out_dims: int = 768, **kwargs): def __init__(self, input_resolution: int = 384, heads: int = 8, out_dims: int = 768, pool_with_query=False, **kwargs): torch.nn.Module.__init__(self) self.input_resolution = input_resolution self.caformer = CAFormerBuilder(**kwargs)() if pool_with_query: self.attnpool = nn.Sequential( AttentionPool2d_query(self.input_resolution // 32, self.caformer.output_dim, heads, out_dims, n_query=8), AttentionPool2d_flat(8, out_dims, heads, out_dims), ) else: self.attnpool = AttentionPool2d(self.input_resolution//32, self.caformer.output_dim, heads, out_dims) def _get_cnn_result(self, x): Loading @@ -33,6 +39,13 @@ def get_caformer(input_resolution: int = 384, heads: int = 8, feat_dims: int = 7 return CaformerBackbone(input_resolution, heads, feat_dims, **kwargs), transform def get_caformer_query(input_resolution: int = 384, heads: int = 8, feat_dims: int = 768, **kwargs): transform = [ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ] return CaformerBackbone(input_resolution, heads, feat_dims, pool_with_query=True, **kwargs), transform def get_caformer_s18(input_resolution: int = 384, heads: int = 8, feat_dims: int = 768, **kwargs): transform = [ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), Loading
zoo/ccip/dataset.py +44 −30 Original line number Diff line number Diff line Loading @@ -29,10 +29,28 @@ class WeakRandAugment(transforms.RandAugment): "Equalize":(torch.tensor(0.0), False), } class WeakRandAugment2(transforms.RandAugment): def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[torch.Tensor, bool]]: return { # op_name: (magnitudes, signed) "Identity":(torch.tensor(0.0), False), "ShearX":(torch.linspace(0.0, 0.08, num_bins), True), "ShearY":(torch.linspace(0.0, 0.08, num_bins), True), "TranslateX":(torch.linspace(0.0, 0.08*image_size[1], num_bins), True), "TranslateY":(torch.linspace(0.0, 0.08*image_size[0], num_bins), True), "Rotate":(torch.linspace(0.0, 8.0, num_bins), True), "Brightness":(torch.linspace(0.0, 0.05, num_bins), True), "Contrast":(torch.linspace(0.0, 0.05, num_bins), True), "Sharpness":(torch.linspace(0.0, 0.2, num_bins), True), "Posterize":(8-(torch.arange(num_bins)/((num_bins-1)/4)).round().int(), False), "AutoContrast":(torch.tensor(0.0), False), "Equalize":(torch.tensor(0.0), False), } TRAIN_TRANSFORM = [ transforms.Resize((416, 416)), transforms.RandomHorizontalFlip(), WeakRandAugment(), WeakRandAugment2(), transforms.RandomCrop(384), transforms.ToTensor(), ] Loading @@ -42,7 +60,6 @@ TEST_TRANSFORM = [ transforms.ToTensor(), ] class ImagesDataset(Dataset): def __init__(self, items: List[Tuple[str, int]], transform=None): self.items: List[Tuple[str, int]] = items Loading Loading @@ -73,7 +90,6 @@ class ImagesDataset(Dataset): return ImagesDataset(train_items, train_transform or self.transform), \ ImagesDataset(test_items, test_transform or self.transform) class CCIPImagesDataset(ImagesDataset): def __init__(self, root_dir, transform=None): _maxid = 0 Loading @@ -88,7 +104,6 @@ class CCIPImagesDataset(ImagesDataset): super(CCIPImagesDataset, self).__init__(_items, transform) class CharacterDataset(Dataset): def __init__(self, images_dataset: ImagesDataset, group_size: int = 100, prob: float = 
0.5, force_prob: bool = True): Loading Loading @@ -140,7 +155,6 @@ class CharacterDataset(Dataset): return torch.stack(list(map(torch.as_tensor, images))), \ torch.stack(list(map(torch.as_tensor, labels))) class FastCharacterDataset(Dataset): def __init__(self, images_dataset: ImagesDataset, group_size: int = 100, prob: float = 0.5, **kwargs): Loading @@ -162,7 +176,8 @@ class FastCharacterDataset(Dataset): def reset(self): idxs = np.arange(0, len(self.images_dataset.items)) np.random.shuffle(idxs) self.idxs = idxs[:-(len(idxs) % self.group_size)] rest = len(idxs)%self.group_size self.idxs = idxs[:-rest] if rest>0 else idxs def __len__(self): return len(self.idxs) Loading @@ -182,7 +197,6 @@ class FastCharacterDataset(Dataset): return image, cid def char_collect_fn(batch): img_list, cid_list = [], [] for data in batch: Loading
zoo/ccip/train_.py +1 −1 Original line number Diff line number Diff line Loading @@ -145,7 +145,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio steps_per_epoch=len(train_dataloader)//accelerator.num_processes, epochs=max_epochs, pct_start=0.15, final_div_factor=20. ) # model = torch.compile(model) model = torch.compile(model) model, optimizer, train_dataloader, test_dataloader, scheduler = \ accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) Loading