Commit 30172f68 authored by narugo1992
Browse files

dev(narugo): use cosine similarity

parent 3e1f569e
Loading
Loading
Loading
Loading
+19 −3
Original line number Diff line number Diff line
@@ -20,16 +20,32 @@ class DiffMethod(nn.Module):
        return x


class CCIP(torch.nn.Module):
class CCIPFeature(torch.nn.Module):
    """Feature extractor that returns unit-length (L2-normalised) vectors.

    ``get_backbone(name)`` supplies two objects: the backbone module and a
    companion ``preprocess`` object, which is stored unchanged on the
    instance so callers can reuse it.
    """

    def __init__(self, name: str = "clip/ViT-B/32"):
        """Build the extractor.

        :param name: Identifier handed to ``get_backbone`` to select the
            backbone (defaults to ``"clip/ViT-B/32"``).
        """
        super().__init__()
        self.backbone, self.preprocess = get_backbone(name)

    def forward(self, x):
        """Run ``x`` through the backbone and normalise along the last axis.

        :param x: Input accepted by the backbone.
        :return: Backbone output rescaled so every vector along ``dim=-1``
            has L2 norm 1 (an all-zero vector stays all-zero).
        """
        x = self.backbone(x)
        # ``normalize`` divides by max(norm, eps); a plain ``x / x.norm()``
        # would yield NaN for a zero vector and poison everything downstream.
        return torch.nn.functional.normalize(x, p=2.0, dim=-1)


class CCIP(torch.nn.Module):
    """Pairwise model: feature extractor + cosine similarity + ``DiffMethod`` head.

    Two inputs are embedded with the same ``CCIPFeature`` backbone, compared
    by cosine similarity, and the similarity is passed to ``self.diff``.
    """

    def __init__(self, name: str = "clip/ViT-B/32"):
        """Build the model.

        :param name: Backbone identifier forwarded to ``CCIPFeature``.
        """
        super().__init__()
        self.backbone = CCIPFeature(name)
        self.diff = DiffMethod()
        self.cos_sim = torch.nn.CosineSimilarity(dim=-1)

    @property
    def preprocess(self):
        """The ``preprocess`` object exposed by the underlying feature extractor."""
        return self.backbone.preprocess

    def forward(self, x, y):
        """Compare two inputs element-wise along the batch.

        :param x: First input, fed to the shared backbone.
        :param y: Second input, fed to the same backbone.
        :return: Whatever ``self.diff`` produces for the similarity column.
        """
        feat_x = self.backbone(x)
        feat_y = self.backbone(y)
        # CosineSimilarity reduces the feature axis away; restore a trailing
        # axis of size 1 so ``diff`` gets the same (..., 1) layout it is fed
        # elsewhere via ``sims.reshape(-1, 1)``.
        sim = self.cos_sim(feat_x, feat_y).unsqueeze(-1)
        return self.diff(sim)


@@ -44,4 +60,4 @@ if __name__ == '__main__':
    print(d1.shape, d1.dtype)
    print(d2.shape, d2.dtype)

    print(F.softmax(model.forward(d1, d2)))
    print(F.softmax(model.forward(d1, d2), dim=-1))
+9 −14
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ import re
from typing import Optional

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from ditk import logging
from hbutils.random import global_seed
@@ -12,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose
from tqdm.auto import tqdm

from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset
from .dataset import TRAIN_TRANSFORM, CCIPImagesDataset, CharacterDataset, TEST_TRANSFORM
from .loss import FocalLoss
from .model import CCIP
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
@@ -91,14 +90,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        writer = None

    model = CCIP(model_name)
    preprocess = model.preprocess
    train_transform = Compose([*TRAIN_TRANSFORM.transforms, *preprocess.transforms])
    test_transform = preprocess

    image_dataset = CCIPImagesDataset(dataset_dir)
    train_image_dataset, test_image_dataset = image_dataset.split_dataset(test_prob=1 - train_ratio)
    train_image_dataset.transform = train_transform
    test_image_dataset.transform = test_transform
    train_image_dataset.transform = Compose([*TRAIN_TRANSFORM.transforms, *model.preprocess.transforms])
    test_image_dataset.transform = Compose([*TEST_TRANSFORM.transforms, *model.preprocess.transforms])

    train_dataset = CharacterDataset(train_image_dataset, group_size)
    test_dataset = CharacterDataset(test_image_dataset, group_size)
@@ -142,11 +137,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            m_labels = (char_ids == char_ids.reshape(-1, 1))  # BxB

            features = model.backbone(inputs)  # BxF
            m_dists = F.pairwise_distance(features, features.unsqueeze(1))  # BxB
            dists = m_dists[mask].to(accelerator.device)
            m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1))  # BxB
            sims = m_sims[mask].to(accelerator.device)
            labels = m_labels[mask].type(torch.long).to(accelerator.device)

            outputs = model.diff(dists.reshape(-1, 1))
            outputs = model.diff(sims.reshape(-1, 1))
            preds = torch.argmax(outputs, dim=1)
            train_correct += (preds == labels).sum().item()
            train_fp += (preds[labels == 0] == 1).sum().item()
@@ -190,11 +185,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                    m_labels = (char_ids == char_ids.reshape(-1, 1))  # BxB

                    features = model.backbone(inputs)  # BxF
                    m_dists = F.pairwise_distance(features, features.unsqueeze(1))  # BxB
                    dists = m_dists[mask].to(accelerator.device)
                    m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1))  # BxB
                    sims = m_sims[mask].to(accelerator.device)
                    labels = m_labels[mask].type(torch.long).to(accelerator.device)

                    outputs = model.diff(dists.reshape(-1, 1))
                    outputs = model.diff(sims.reshape(-1, 1))
                    preds = torch.argmax(outputs, dim=1)
                    test_correct += (preds == labels).sum().item()
                    test_fp += (preds[labels == 0] == 1).sum().item()