Commit 27adaabf authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update classify

parent 6cac3382
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.validate import anime_classify
from imgutils.validate.classify import _MODEL_NAMES
from imgutils.validate.classify import _REPO_ID

_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names


class AnimeClassifyBenchmark(BaseBenchmark):
@@ -11,12 +14,10 @@ class AnimeClassifyBenchmark(BaseBenchmark):
        self.model = model

    def load(self):
        from imgutils.validate.classify import _open_anime_classify_model
        _ = _open_anime_classify_model(self.model)
        _open_models_for_repo_id(_REPO_ID)._open_model(self.model)

    def unload(self):
        from imgutils.validate.classify import _open_anime_classify_model
        _open_anime_classify_model.cache_clear()
        _open_models_for_repo_id(_REPO_ID).clear()

    def run(self):
        image_file = random.choice(self.all_images)
+0 −2782

File deleted.

Preview size limit exceeded, changes collapsed.

+6 −54
Original line number Diff line number Diff line
@@ -15,62 +15,18 @@ Overview:
    The models are hosted on
    `huggingface - deepghs/anime_classification <https://huggingface.co/deepghs/anime_classification>`_.
"""
from functools import lru_cache
from typing import Tuple, Optional, Dict
from typing import Tuple, Dict

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model
from ..data import ImageTyping
from ..generic import classify_predict, classify_predict_score

__all__ = [
    'anime_classify_score',
    'anime_classify',
]

_LABELS = ['3d', 'bangumi', 'comic', 'illustration']
_MODEL_NAMES = [
    'caformer_s36',
    'caformer_s36_plus',
    'mobilenetv3',
    'mobilenetv3_dist',
    'mobilenetv3_sce',
    'mobilenetv3_sce_dist',
    'mobilevitv2_150',
]
_DEFAULT_MODEL_NAME = 'mobilenetv3_sce_dist'


@lru_cache()
def _open_anime_classify_model(model_name):
    return open_onnx_model(hf_hub_download(
        f'deepghs/anime_classification',
        f'{model_name}/model.onnx',
    ))


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')

    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data.astype(np.float32)


def _raw_anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME):
    image = load_image(image, force_background='white', mode='RGB')
    input_ = _img_encode(image)[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})

    return output
_REPO_ID = 'deepghs/anime_classification'


def anime_classify_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
@@ -111,9 +67,7 @@ def anime_classify_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NA
        >>> anime_classify_score('classify/illustration/12.jpg')
        {'3d': 3.153582292725332e-05, 'bangumi': 0.0001071861624950543, 'comic': 5.665345452143811e-05, 'illustration': 0.999804675579071}
    """
    output = _raw_anime_classify(image, model_name)
    values = dict(zip(_LABELS, map(lambda x: x.item(), output[0])))
    return values
    return classify_predict_score(image, _REPO_ID, model_name)


def anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
@@ -154,6 +108,4 @@ def anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) ->
        >>> anime_classify('classify/illustration/12.jpg')
        ('illustration', 0.999804675579071)
    """
    output = _raw_anime_classify(image, model_name)[0]
    max_id = np.argmax(output)
    return _LABELS[max_id], output[max_id].item()
    return classify_predict(image, _REPO_ID, model_name)
+3 −2
Original line number Diff line number Diff line
@@ -3,8 +3,9 @@ import os.path

import pytest

from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.validate import anime_classify
from imgutils.validate.classify import _open_anime_classify_model, anime_classify_score
from imgutils.validate.classify import anime_classify_score, _REPO_ID
from test.testings import get_testfile

_ROOT_DIR = get_testfile('anime_cls')
@@ -19,7 +20,7 @@ def _release_model_after_run():
    try:
        yield
    finally:
        _open_anime_classify_model.cache_clear()
        _open_models_for_repo_id(_REPO_ID).clear()


@pytest.mark.unittest