Loading zoo/ccip/model.py +19 −3 Original line number Diff line number Diff line Loading @@ -20,16 +20,32 @@ class DiffMethod(nn.Module): return x class CCIP(torch.nn.Module): class CCIPFeature(torch.nn.Module): def __init__(self, name: str = "clip/ViT-B/32"): torch.nn.Module.__init__(self) self.backbone, self.preprocess = get_backbone(name) def forward(self, x): x = self.backbone(x) x = x / x.norm(dim=-1, keepdim=True) return x class CCIP(torch.nn.Module): def __init__(self, name: str = "clip/ViT-B/32"): torch.nn.Module.__init__(self) self.backbone = CCIPFeature(name) self.diff = DiffMethod() self.cos_sim = torch.nn.CosineSimilarity(dim=-1) @property def preprocess(self): return self.backbone.preprocess def forward(self, x, y): x = self.backbone(x) y = self.backbone(y) dis = F.pairwise_distance(x, y, keepdim=True) dis = self.cos_sim(x, y) return self.diff(dis) Loading @@ -44,4 +60,4 @@ if __name__ == '__main__': print(d1.shape, d1.dtype) print(d2.shape, d2.dtype) print(F.softmax(model.forward(d1, d2))) print(F.softmax(model.forward(d1, d2), dim=-1)) zoo/ccip/train_.py +9 −14 Original line number Diff line number Diff line Loading @@ -3,7 +3,6 @@ import re from typing import Optional import torch import torch.nn.functional as F from accelerate import Accelerator from ditk import logging from hbutils.random import global_seed Loading @@ -12,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM from .loss import FocalLoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading Loading @@ -91,14 +90,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio writer = None model = CCIP(model_name) preprocess = model.preprocess train_transform = Compose([*TRAIN_TRANSFORM.transforms, *preprocess.transforms]) test_transform = preprocess image_dataset = CCIPImagesDataset(dataset_dir) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio) train_image_dataset.transform = train_transform test_image_dataset.transform = test_transform train_image_dataset.transform = Compose([*TRAIN_TRANSFORM.transforms, *model.preprocess.transforms]) test_image_dataset.transform = Compose([*TEST_TRANSFORM.transforms, *model.preprocess.transforms]) train_dataset = CharacterDataset(train_image_dataset, group_size) test_dataset = CharacterDataset(test_image_dataset, group_size) Loading Loading @@ -142,11 +137,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1)) # BxB sims = m_sims[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) outputs = model.diff(sims.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() Loading Loading @@ -190,11 +185,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1)) # BxB sims = m_sims[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) outputs = model.diff(sims.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) test_correct += (preds == labels).sum().item() test_fp += (preds[labels == 0] == 1).sum().item() Loading Loading
zoo/ccip/model.py +19 −3 Original line number Diff line number Diff line Loading @@ -20,16 +20,32 @@ class DiffMethod(nn.Module): return x class CCIP(torch.nn.Module): class CCIPFeature(torch.nn.Module): def __init__(self, name: str = "clip/ViT-B/32"): torch.nn.Module.__init__(self) self.backbone, self.preprocess = get_backbone(name) def forward(self, x): x = self.backbone(x) x = x / x.norm(dim=-1, keepdim=True) return x class CCIP(torch.nn.Module): def __init__(self, name: str = "clip/ViT-B/32"): torch.nn.Module.__init__(self) self.backbone = CCIPFeature(name) self.diff = DiffMethod() self.cos_sim = torch.nn.CosineSimilarity(dim=-1) @property def preprocess(self): return self.backbone.preprocess def forward(self, x, y): x = self.backbone(x) y = self.backbone(y) dis = F.pairwise_distance(x, y, keepdim=True) dis = self.cos_sim(x, y) return self.diff(dis) Loading @@ -44,4 +60,4 @@ if __name__ == '__main__': print(d1.shape, d1.dtype) print(d2.shape, d2.dtype) print(F.softmax(model.forward(d1, d2))) print(F.softmax(model.forward(d1, d2), dim=-1))
zoo/ccip/train_.py +9 −14 Original line number Diff line number Diff line Loading @@ -3,7 +3,6 @@ import re from typing import Optional import torch import torch.nn.functional as F from accelerate import Accelerator from ditk import logging from hbutils.random import global_seed Loading @@ -12,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter from torchvision.transforms import Compose from tqdm.auto import tqdm from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM from .loss import FocalLoss from .model import CCIP from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR Loading Loading @@ -91,14 +90,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio writer = None model = CCIP(model_name) preprocess = model.preprocess train_transform = Compose([*TRAIN_TRANSFORM.transforms, *preprocess.transforms]) test_transform = preprocess image_dataset = CCIPImagesDataset(dataset_dir) train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio) train_image_dataset.transform = train_transform test_image_dataset.transform = test_transform train_image_dataset.transform = Compose([*TRAIN_TRANSFORM.transforms, *model.preprocess.transforms]) test_image_dataset.transform = Compose([*TEST_TRANSFORM.transforms, *model.preprocess.transforms]) train_dataset = CharacterDataset(train_image_dataset, group_size) test_dataset = CharacterDataset(test_image_dataset, group_size) Loading Loading @@ -142,11 +137,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1)) # BxB sims = m_sims[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) outputs = model.diff(sims.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() Loading Loading @@ -190,11 +185,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio m_labels = (char_ids == char_ids.reshape(-1, 1)) # BxB features = model.backbone(inputs) # BxF m_dists = F.pairwise_distance(features, features.unsqueeze(1)) # BxB dists = m_dists[mask].to(accelerator.device) m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1)) # BxB sims = m_sims[mask].to(accelerator.device) labels = m_labels[mask].type(torch.long).to(accelerator.device) outputs = model.diff(dists.reshape(-1, 1)) outputs = model.diff(sims.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) test_correct += (preds == labels).sum().item() test_fp += (preds[labels == 0] == 1).sum().item() Loading