Commit 22053872 authored by dzy7e's avatar dzy7e
Browse files

ccip clip

parent 3f023015
Loading
Loading
Loading
Loading
+3 −7
Original line number Diff line number Diff line
@@ -26,13 +26,9 @@ class CaformerBackbone(torch.nn.Module):
        return x


def get_caformer(input_resolution: int = 224, heads: int = 32, feat_dims: int = 1024, **kwargs):
    transform = Compose([
        Resize(input_resolution, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(input_resolution),
        lambda x: x.convert('RGB'),
        ToTensor(),
def get_caformer(input_resolution: int = 384, heads: int = 32, feat_dims: int = 1024, **kwargs):
    transform = [
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    ]

    return CaformerBackbone(input_resolution, heads, feat_dims, **kwargs), transform
+59 −7
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ import os.path
import random
from typing import List, Tuple, Dict

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
@@ -11,14 +12,19 @@ from torchvision import transforms
from imgutils.data import load_image
from .prob import get_reg_for_prob

TRAIN_TRANSFORM = transforms.Compose([
    transforms.Resize(300),
    transforms.RandomCrop(250, padding=25, pad_if_needed=True, padding_mode='reflect'),
TRAIN_TRANSFORM = [
    transforms.Resize(416),
    transforms.RandomRotation((-15, 15)),
    transforms.RandomCrop(384),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.10, 0.10, 0.10, 0.10),
])
TEST_TRANSFORM = transforms.Compose([])
    transforms.ToTensor(),
]
TEST_TRANSFORM = [
    transforms.Resize(416),
    transforms.CenterCrop(384),
    transforms.ToTensor(),
]


class ImagesDataset(Dataset):
@@ -37,7 +43,7 @@ class ImagesDataset(Dataset):

        return image, idx

    def split_dataset(self, test_prob: float = 0.2):
    def split_dataset(self, test_prob: float = 0.2, train_transform=None, test_transform=None):
        total = len(self.items)
        test_ids = set(random.sample(list(range(total)), k=int(total * test_prob)))

@@ -48,7 +54,7 @@ class ImagesDataset(Dataset):
            else:
                train_items.append(item)

        return ImagesDataset(train_items, self.transform), ImagesDataset(test_items, self.transform)
        return ImagesDataset(train_items, train_transform or self.transform), ImagesDataset(test_items, test_transform or self.transform)


class CCIPImagesDataset(ImagesDataset):
@@ -116,3 +122,49 @@ class CharacterDataset(Dataset):

        return torch.stack(list(map(torch.as_tensor, images))), \
               torch.stack(list(map(torch.as_tensor, labels)))


class FastCharacterDataset(Dataset):
    def __init__(self, images_dataset: ImagesDataset, group_size: int = 100,
                 prob: float = 0.5, **kwargs):
        self.images_dataset = images_dataset

        groups: Dict[int, List[int]] = {}
        for i, (_, cid) in enumerate(self.images_dataset.items):
            if cid not in groups:
                groups[cid] = []
            groups[cid].append(i)
        self.groups = {k:np.array(v) for k,v in groups.items()}

        self.group_size = group_size
        self.prob = prob*2

    def reset(self):
        idxs = np.arange(0, len(self.images_dataset.items))
        np.random.shuffle(idxs)
        self.idxs = idxs[:-(len(idxs)%self.group_size)]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, cid = self.images_dataset[self.idxs[item]]
        n_same = int(self.prob) + int((self.prob-int(self.prob))<=(item%self.group_size/self.group_size))

        image = [image]
        cid = [cid]
        if n_same>0:
            same_idxs = random.sample(self.groups[cid], k=n_same)
            for idx in same_idxs:
                img_i, cid_i = self.images_dataset[idx]
                image.append(img_i)
                cid.append(cid_i)

        return image, cid

def char_collect_fn(batch):
    img_list, cid_list = [], []
    for data in batch:
        img_list.extend(data[0])
        cid_list.extend(data[1])
    return torch.stack(img_list), torch.tensor(cid_list)
 No newline at end of file
+22 −0
Original line number Diff line number Diff line
@@ -52,3 +52,25 @@ class NTXentLoss(nn.Module):

        pos_tensor = torch.stack(pos_items)
        return (pos_tensor.sum() + self.eps) / (pos_tensor.shape[0] + self.eps)

class MLCELoss(nn.Module):
    def __init__(self, weight=None, reduction='mean', eps=1e-4):
        super().__init__()
        weight = torch.as_tensor(weight).float() if weight is not None else weight
        self.register_buffer('weight', weight)
        self.reduction = reduction
        self.eps = eps

    def forward(self, input_tensor, target_tensor):
        log_prob_raw = F.softmax(input_tensor, dim=1)

        same_mask = (target_tensor.unsqueeze(0) == target_tensor.unsqueeze(1)).long() # [B,B]
        same_mask_diag0 = same_mask - torch.diag_embed(torch.diag(same_mask)) # diag=0

        log_prob_x = log_prob_raw.clone()
        log_prob_x[same_mask_diag0.bool()] = self.eps
        log_prob_x.diagonal().copy_((log_prob_raw*same_mask_diag0).sum(dim=1))
        log_prob_x = log_prob_x + torch.diag_embed(torch.ones(len(target_tensor))*self.eps)
        y = torch.arange(0, len(target_tensor))

        return F.nll_loss(log_prob_x.log(), y, weight=self.weight, reduction=self.reduction)
 No newline at end of file
+34 −19
Original line number Diff line number Diff line
@@ -2,19 +2,29 @@ import torch.nn
import torch.nn.functional as F
from PIL import Image
from torch import nn
import numpy as np

from zoo.utils import get_testfile
#from zoo.utils import get_testfile
from .backbone import get_backbone


class CCIPBatchMetrics(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.sim = nn.CosineSimilarity(dim=-1)
        self.logit_scale = nn.Parameter(torch.ones([])*np.log(1/0.07))
        #self.sim = nn.CosineSimilarity(dim=-1)

    def forward(self, x):  # x: BxN
        x = self.sim(x, x.unsqueeze(1))  # BxB
        return x
    def forward(self, image_features):  # x: BxN

        # normalized features
        image_features = image_features/image_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale*image_features@image_features.t()
        logits_per_image = logits_per_image - torch.diag_embed(torch.diag(logits_per_image))

        return logits_per_image


class CCIPFeature(torch.nn.Module):
@@ -46,17 +56,22 @@ class CCIP(torch.nn.Module):


if __name__ == '__main__':
    image_files = [
        get_testfile('6124220.jpg'),
        get_testfile('6125785.jpg'),
        get_testfile('6125901.jpg'),
    ]

    model = CCIP()
    data = torch.stack([
        model.preprocess(Image.open(img))
        for img in image_files
    ])
    print(data.dtype, data.shape)

    print(F.softmax(model.forward(data), dim=-1))
    # image_files = [
    #     get_testfile('6124220.jpg'),
    #     get_testfile('6125785.jpg'),
    #     get_testfile('6125901.jpg'),
    # ]
    #
    # model = CCIP()
    # data = torch.stack([
    #     model.preprocess(Image.open(img))
    #     for img in image_files
    # ])
    # print(data.dtype, data.shape)
    #
    # print(F.softmax(model.forward(data), dim=-1))

    data = torch.randn(4,3,384,384).cuda()
    model = CCIP('caformer').cuda()
    print(model(data))
+52 −32
Original line number Diff line number Diff line
@@ -12,10 +12,11 @@ from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM
from .loss import NTXentLoss
from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM, char_collect_fn
from .loss import NTXentLoss, MLCELoss
from .model import CCIP
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

@@ -81,8 +82,8 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):

def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30,
          learning_rate: float = 0.001, weight_decay: float = 1e-3, tau: float = 0.15,
          save_per_epoch: int = 10, eval_epoch: int = 5,
          learning_rate: float = 0.001, weight_decay: float = 1e-2, tau: float = 0.15,
          save_per_epoch: int = 10, eval_epoch: int = 5, num_workers=8,
          model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0):
    if seed is not None:
        # native random, numpy, torch and faker's seeds are includes
@@ -113,12 +114,15 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

    model = CCIP(model_name)
    image_dataset = CCIPImagesDataset(dataset_dir)
    train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio)
    train_image_dataset.transform = Compose([*TRAIN_TRANSFORM.transforms, *model.preprocess.transforms])
    test_image_dataset.transform = Compose([*TEST_TRANSFORM.transforms, *model.preprocess.transforms])
    train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio,
                                    train_transform=Compose(TRAIN_TRANSFORM.transforms+model.preprocess.transforms),
                                    test_transform=Compose(TEST_TRANSFORM.transforms+model.preprocess.transforms),)

    train_dataset = CharacterDataset(train_image_dataset, group_size, force_prob=False)
    test_dataset = CharacterDataset(test_image_dataset, group_size)
    train_dataloader = DataLoader(train_dataset, batch_size=group_size, shuffle=True, num_workers=num_workers, collate_fn=char_collect_fn,
                                  drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=group_size, num_workers=num_workers, collate_fn=char_collect_fn)

    if from_ckpt is None:
        from_ckpt = _find_latest_ckpt(session_name)
@@ -129,7 +133,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    else:
        logging.info(f'No checkpoint found, new model will be used.')

    loss_fn = NTXentLoss(tau=tau).to(accelerator.device)
    loss_fn = MLCELoss().to(accelerator.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,
@@ -137,36 +141,52 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        pct_start=0.15, final_div_factor=20.
    )

    model, optimizer, train_dataset, test_dataset, scheduler = \
        accelerator.prepare(model, optimizer, train_dataset, test_dataset, scheduler)
    model, optimizer, train_dataloader, test_dataloader, scheduler = \
        accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)
    test_dataloader.dataset.reset()

    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0
        train_pos_total = 0
        positive_sims, negative_sims = [], []
        model.train()
        for i, (inputs, char_ids) in enumerate(tqdm(train_dataset)):
            inputs = inputs.float()
        for i, (inputs, char_ids) in enumerate(tqdm(train_dataloader)):
            train_dataloader.dataset.reset()
            inputs = inputs.to(accelerator.device)  # BxCxHxW
            char_ids = char_ids.to(accelerator.device)  # B

            ix = torch.arange(0, char_ids.shape[0])
            mask = ix > ix.reshape(-1, 1)  # BxB, remove duplicated
            similarities = model(inputs)  # BxB
            outputs = similarities[mask]  # N
            labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device)  # N
            # B = len(char_ids)
            # mask = torch.triu(torch.ones(B,B),diagonal=1).to(accelerator.device)  # BxB, remove duplicated
            # similarities = model(inputs)  # BxB
            # outputs = similarities[mask]  # N
            # labels = (char_ids.view(-1,1) == char_ids.view(1,-1))[mask]  # N
            # labels = char_ids

            outputs = model(inputs)  # BxB
            labels = char_ids

            loss = loss_fn(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            train_pos_total += labels.sum()
            running_loss += loss.item() * labels.sum()
            scheduler.step()

            positive_sims.append(outputs[labels])
            negative_sims.append(outputs[~labels])

        epoch_loss = running_loss / train_pos_total
            running_loss += loss.item()*len(char_ids)

            gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
            outputs = outputs.detach().cpu()
            gt_diag0 = gt.clone()
            gt_diag0.diagonal().copy_(torch.zeros(len(char_ids)))
            # outputs.diagonal().copy_(torch.ones(len(char_ids))*-10000)
            # max_idxs = outputs.argsort(dim=-1)
            # for max_idx, n_pos in zip(max_idxs, gt.sum(dim=1)):
            #     train_pos_total += n_pos
            #     positive_sims.append(outputs[labels])
            #     negative_sims.append(outputs[~labels])
            train_pos_total += gt_diag0.sum()
            positive_sims.append(outputs[gt_diag0])
            negative_sims.append(outputs[~gt])

        epoch_loss = running_loss #/ train_pos_total
        train_psims = torch.cat(positive_sims)
        train_nsims = torch.cat(negative_sims)
        train_pos_mean, train_pos_std, train_neg_mean, train_neg_std, train_threshold, train_acc_svm = \
@@ -188,19 +208,19 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        if epoch % eval_epoch == 0:
            with torch.no_grad():
                positive_sims, negative_sims = [], []
                for i, (inputs, char_ids) in enumerate(tqdm(test_dataset)):
                    inputs = inputs.float()
                for i, (inputs, char_ids) in enumerate(tqdm(test_dataloader)):
                    inputs = inputs.to(accelerator.device)  # BxCxHxW
                    char_ids = char_ids.to(accelerator.device)  # B

                    ix = torch.arange(0, char_ids.shape[0])
                    mask = ix > ix.reshape(-1, 1)  # BxB, remove duplicated
                    similarities = model(inputs)  # BxB
                    outputs = similarities[mask]  # N
                    labels = (char_ids == char_ids.reshape(-1, 1))[mask].to(accelerator.device)  # N
                    outputs = model(inputs)  # BxB

                    positive_sims.append(outputs[labels])
                    negative_sims.append(outputs[~labels])
                    gt = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
                    outputs = outputs.detach().cpu()
                    gt_diag0 = gt.clone()
                    gt_diag0.diagonal().copy_(torch.zeros(len(char_ids)))
                    train_pos_total += gt_diag0.sum()
                    positive_sims.append(outputs[gt_diag0])
                    negative_sims.append(outputs[~gt])

                test_psims = torch.cat(positive_sims)
                test_nsims = torch.cat(negative_sims)