Commit 6ee99250 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): temp save

parent 6d4798d4
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -13,3 +13,6 @@ tensorboard
einops
thop
accelerate
ftfy
regex
git+https://github.com/openai/CLIP.git
 No newline at end of file

zoo/ccip/__init__.py

0 → 100644
+0 −0

Empty file added.

zoo/ccip/backbone.py

0 → 100644
+33 −0
Original line number Diff line number Diff line
from functools import partial
from typing import Tuple, Dict, Callable

import clip
import torch
from torchvision.transforms import Compose


def get_clip_backbone(name="ViT-B/32") -> Tuple[torch.nn.Module, Compose]:
    """
    Load a CLIP model on CPU and return its image encoder with the matching preprocessing.

    :param name: Model name handed to ``clip.load``.
    :return: Tuple of the model's ``visual`` submodule cast to ``torch.float32`` and the
        preprocessing object returned by ``clip.load``.
    """
    clip_model, transform = clip.load(name, device='cpu')
    visual_encoder = clip_model.visual.type(torch.float32)
    return visual_encoder, transform


# Backbone names starting with this prefix are resolved by ``get_backbone`` through
# the ``clip`` package; the remainder of the name is the CLIP model name.
CLIP_PREFIX = 'clip/'
# Registry of non-CLIP backbones: maps a backbone name to a zero-argument factory
# (a ``functools.partial`` stored by ``register_backbone``).
_KNOWN_BACKBONES: Dict[str, Callable[..., Tuple[torch.nn.Module, Compose]]] = {}


def register_backbone(name, func, *args, **kwargs):
    """
    Register a backbone factory under ``name``.

    The extra positional and keyword arguments are bound to ``func`` with
    ``functools.partial``, so the stored factory can later be called without arguments.
    An existing entry with the same name is overwritten.

    :param name: Key under which the factory is stored in ``_KNOWN_BACKBONES``.
    :param func: Callable that builds the backbone.
    """
    factory = partial(func, *args, **kwargs)
    _KNOWN_BACKBONES[name] = factory


def get_backbone(name: str) -> Tuple[torch.nn.Module, Compose]:
    """
    Resolve a backbone by name.

    Names starting with ``CLIP_PREFIX`` are looked up among ``clip.available_models()``
    (after stripping the prefix); every other name is looked up in ``_KNOWN_BACKBONES``.

    :param name: Backbone name, e.g. ``'clip/ViT-B/32'`` or a registered name.
    :return: Whatever the selected factory returns (a model and its preprocessing).
    :raises ValueError: If the name is neither an available CLIP model nor registered.
    """
    # Registered (non-CLIP) backbones.
    if not name.startswith(CLIP_PREFIX):
        if name not in _KNOWN_BACKBONES:
            raise ValueError(f'Unknown backbone - {name!r}.')
        return _KNOWN_BACKBONES[name]()

    # CLIP backbones: drop the prefix and validate against the clip package.
    clip_name = name[len(CLIP_PREFIX):]
    if clip_name not in clip.available_models():
        raise ValueError(f'Unknown model in clip - {clip_name!r}.')
    return get_clip_backbone(clip_name)

zoo/ccip/dataset.py

0 → 100644
+34 −0
Original line number Diff line number Diff line
import glob
import os.path
from typing import List, Tuple

from PIL import Image
from torch.utils.data import Dataset

from imgutils.data import load_image


class CCIPDataset(Dataset):
    """
    Dataset of ``*.jpg`` files found exactly two directory levels below ``root_dir``.

    Every distinct directory that directly contains images gets its own integer id,
    and each item is an ``(image, id)`` pair.
    """

    def __init__(self, root_dir, transform=None):
        """
        Scan ``root_dir`` and build the ``(file, id)`` index.

        :param root_dir: Directory searched with the pattern ``root_dir/*/*/*.jpg``.
        :param transform: Optional callable applied to each loaded image.
        """
        self.items: List[Tuple[str, int]] = []
        self.transform = transform

        dir_ids = {}
        # The order of glob results depends on the file system. Sorting makes the
        # directory -> id assignment (and the item order) reproducible across runs
        # and machines.
        for file in sorted(glob.glob(os.path.join(root_dir, '*', '*', '*.jpg'))):
            dirname = os.path.normcase(os.path.normpath(os.path.dirname(os.path.abspath(file))))
            # A directory seen for the first time receives the next free id.
            label = dir_ids.setdefault(dirname, len(dir_ids))
            self.items.append((file, label))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.items)

    def __getitem__(self, index) -> Tuple[object, int]:
        """
        Load the image at ``index`` in RGB mode and return it with its directory id.

        :param index: Position in ``self.items``.
        :return: ``(image, id)``; ``image`` is the result of ``load_image`` or, when a
            transform was given, whatever that transform returns.
        """
        filename, idx = self.items[index]
        image = load_image(filename, mode='RGB')
        if self.transform:
            image = self.transform(image)

        return image, idx

zoo/ccip/loss.py

0 → 100644
+26 −0
Original line number Diff line number Diff line
import torch
from torch import nn

from torch.nn import functional as F


class FocalLoss(nn.Module):
    """
    Multi-class focal loss computed from raw logits.

    Based on https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8
    """

    def __init__(self, weight=None, gamma=2., reduction='mean'):
        """
        :param weight: Optional per-class weights (anything ``torch.as_tensor`` accepts),
            forwarded to ``F.nll_loss``.
        :param gamma: Exponent of the ``(1 - p)`` modulating factor.
        :param reduction: Reduction mode forwarded to ``F.nll_loss``.
        """
        super().__init__()
        # Registered as a buffer so the weights follow the module through
        # .to() / .cuda() / .double(); non-persistent so state_dict keys stay the same.
        self.register_buffer(
            'weight',
            torch.as_tensor(weight).float() if weight is not None else None,
            persistent=False,
        )
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        """
        :param input_tensor: Logits; softmax is taken over the last dimension.
        :param target_tensor: Class indices as expected by ``F.nll_loss``.
        :return: The loss, reduced according to ``self.reduction``.
        """
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        # Scaling log p by (1 - p) ** gamma before nll_loss down-weights
        # samples that are already classified with high confidence.
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )
Loading