Commit cc498cea authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add basic classify client

parent 0dfc7e58
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import anime_classify
from imgutils.validate.classify import _MODEL_NAMES


class AnimeClassifyBenchmark(BaseBenchmark):
    def __init__(self, model):
        BaseBenchmark.__init__(self)
        self.model = model

    def load(self):
        from imgutils.validate.classify import _open_anime_classify_model
        _ = _open_anime_classify_model(self.model)

    def unload(self):
        from imgutils.validate.classify import _open_anime_classify_model
        _open_anime_classify_model.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = anime_classify(image_file, self.model)


if __name__ == '__main__':
    create_plot_cli(
        [
            (name, AnimeClassifyBenchmark(name))
            for name in _MODEL_NAMES
        ],
        title='Benchmark for Anime Classify Models',
        run_times=10,
        try_times=20,
    )()
+1 −0
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
Overview:
    Tools for image validation and classification, which can be used to filter datasets.
"""
from .classify import *
from .color import *
from .monochrome import *
from .truncate import *
+88 −0
Original line number Diff line number Diff line
from functools import lru_cache
from typing import Tuple, Optional, Dict, Union, Mapping

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model

__all__ = [
    'anime_classify',
    'is_3d', 'is_bangumi', 'is_comic', 'is_illustration',
]

_MODEL_METAS = [
    ('mobilenetv3_large_100', 0.533, 0.438, 0.440, 0.446),
    ('mobilevitv2_150', 0.315, 0.354, 0.595, 0.511),
]
_LABELS = ['3d', 'bangumi', 'comic', 'illustration']

_MODEL_NAMES = [name for name, *_ in _MODEL_METAS]
_DEFAULT_MODEL_NAME = _MODEL_NAMES[0]
_MODEL_THRESHOLDS = {name: dict(zip(_LABELS, thresholds)) for name, *thresholds in _MODEL_METAS}


@lru_cache()
def _open_anime_classify_model(model_name):
    return open_onnx_model(hf_hub_download(
        f'deepghs/imgutils-models',
        f'anime_cls/anime_cls_{model_name}.onnx',
    ))


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')

    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data.astype(np.float32)


def _default_thresholds(model_name: str = _DEFAULT_MODEL_NAME) -> Mapping[str, float]:
    return _MODEL_THRESHOLDS[model_name]


def anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME,
                   check: bool = False, thresholds: Optional[Mapping[str, float]] = None) \
        -> Dict[str, Union[float, bool]]:
    image = load_image(image, force_background='white', mode='RGB')
    input_ = _img_encode(image)[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
    values = dict(zip(_LABELS, map(lambda x: x.item(), output[0])))
    thresholds = thresholds or _default_thresholds(model_name)
    if check:
        return {label: values[label] >= thresholds[label] for label in _LABELS}
    else:
        return values


def _is_cls(image: ImageTyping, cls_name: str, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = None):
    thresholds = dict(_default_thresholds(model_name))
    if thresholds is not None:
        thresholds[cls_name] = threshold

    return anime_classify(image, model_name, check=True, thresholds=thresholds)[cls_name]


def is_3d(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = None):
    return _is_cls(image, '3d', model_name, threshold)


def is_bangumi(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = None):
    return _is_cls(image, 'bangumi', model_name, threshold)


def is_comic(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = None):
    return _is_cls(image, 'comic', model_name, threshold)


def is_illustration(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = None):
    return _is_cls(image, 'illustration', model_name, threshold)