Commit 0363082a authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add documentation for anime_classify_score and anime_classify

parent 14f1fd41
Loading
Loading
Loading
Loading
+76 −3
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model

__all__ = [
    'anime_classify_scores',
    'anime_classify_score',
    'anime_classify',
]

@@ -71,14 +71,87 @@ def _raw_anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAM
    return output


def anime_classify_scores(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) \
        -> Dict[str, float]:
def anime_classify_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
    """
    Overview:
        Predict the class of the given image, return the score with as a dict object.

    :param image: Image to classify.
    :param model_name: Model to use. Default is ``mobilenetv3_sce``. All available models are listed
        on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_plus``.
    :return: A dict with classes and scores.

    Examples::
        >>> from imgutils.validate import anime_classify_score
        >>>
        >>> anime_classify_score('classify/3d/1.jpg')
        {'3d': 0.9999048709869385, 'bangumi': 1.2998967577004805e-05, 'comic': 6.3774782574910205e-06, 'illustration': 7.573143375338987e-05}
        >>> anime_classify_score('classify/3d/2.jpg')
        {'3d': 1.0, 'bangumi': 2.3347255118794097e-12, 'comic': 5.393629119720966e-12, 'illustration': 6.71077689945454e-12}
        >>> anime_classify_score('classify/3d/3.jpg')
        {'3d': 0.9999587535858154, 'bangumi': 1.6608031728537753e-05, 'comic': 1.4294577340479009e-05, 'illustration': 1.0324462891730946e-05}
        >>> anime_classify_score('classify/bangumi/4.jpg')
        {'3d': 8.245967464404202e-09, 'bangumi': 0.9999991655349731, 'comic': 2.004386701059957e-08, 'illustration': 8.202430876735889e-07}
        >>> anime_classify_score('classify/bangumi/5.jpg')
        {'3d': 6.440834113163874e-05, 'bangumi': 0.9982288479804993, 'comic': 1.4121969797997735e-05, 'illustration': 0.001692703110165894}
        >>> anime_classify_score('classify/bangumi/6.jpg')
        {'3d': 2.3443080159404883e-14, 'bangumi': 1.0, 'comic': 5.647845608075866e-14, 'illustration': 6.008537851293072e-13}
        >>> anime_classify_score('classify/comic/7.jpg')
        {'3d': 4.029740221408245e-18, 'bangumi': 4.658470278842451e-18, 'comic': 1.0, 'illustration': 2.0487814569869478e-11}
        >>> anime_classify_score('classify/comic/8.jpg')
        {'3d': 1.019530813939351e-11, 'bangumi': 1.5961519215720865e-12, 'comic': 1.0, 'illustration': 2.2395576712574972e-11}
        >>> anime_classify_score('classify/comic/9.jpg')
        {'3d': 2.1237236958165234e-13, 'bangumi': 2.3246717593440602e-14, 'comic': 1.0, 'illustration': 3.4230233231236085e-11}
        >>> anime_classify_score('classify/illustration/10.jpg')
        {'3d': 0.00026091927429661155, 'bangumi': 0.00011691388499457389, 'comic': 8.51359436637722e-05, 'illustration': 0.9995369911193848}
        >>> anime_classify_score('classify/illustration/11.jpg')
        {'3d': 6.014750475458186e-09, 'bangumi': 2.3536564697224094e-07, 'comic': 7.933858796604909e-06, 'illustration': 0.999991774559021}
        >>> anime_classify_score('classify/illustration/12.jpg')
        {'3d': 3.153582292725332e-05, 'bangumi': 0.0001071861624950543, 'comic': 5.665345452143811e-05, 'illustration': 0.999804675579071}
    """
    output = _raw_anime_classify(image, model_name)
    values = dict(zip(_LABELS, map(lambda x: x.item(), output[0])))
    return values


def anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
    """
    Overview:
        Predict the class of the given image, return the class and its score.

    :param image: Image to classify.
    :param model_name: Model to use. Default is ``mobilenetv3_sce``. All available models are listed
        on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_plus``.
    :return: A tuple contains the class and its score.

    Examples::
        >>> from imgutils.validate import anime_classify
        >>>
        >>> anime_classify('classify/3d/1.jpg')
        ('3d', 0.9999048709869385)
        >>> anime_classify('classify/3d/2.jpg')
        ('3d', 1.0)
        >>> anime_classify('classify/3d/3.jpg')
        ('3d', 0.9999587535858154)
        >>> anime_classify('classify/bangumi/4.jpg')
        ('bangumi', 0.9999991655349731)
        >>> anime_classify('classify/bangumi/5.jpg')
        ('bangumi', 0.9982288479804993)
        >>> anime_classify('classify/bangumi/6.jpg')
        ('bangumi', 1.0)
        >>> anime_classify('classify/comic/7.jpg')
        ('comic', 1.0)
        >>> anime_classify('classify/comic/8.jpg')
        ('comic', 1.0)
        >>> anime_classify('classify/comic/9.jpg')
        ('comic', 1.0)
        >>> anime_classify('classify/illustration/10.jpg')
        ('illustration', 0.9995369911193848)
        >>> anime_classify('classify/illustration/11.jpg')
        ('illustration', 0.999991774559021)
        >>> anime_classify('classify/illustration/12.jpg')
        ('illustration', 0.999804675579071)
    """
    output = _raw_anime_classify(image, model_name)[0]
    max_id = np.argmax(output)
    return _LABELS[max_id], output[max_id].item()
+7 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import os.path
import pytest

from imgutils.validate import anime_classify
from imgutils.validate.classify import _open_anime_classify_model
from imgutils.validate.classify import _open_anime_classify_model, anime_classify_score
from test.testings import get_testfile

_ROOT_DIR = get_testfile('anime_cls')
@@ -29,3 +29,9 @@ class TestValidateClassify:
        image_file = get_testfile('anime_cls', image)
        tag, score = anime_classify(image_file)
        assert tag == label

    @pytest.mark.parametrize(['image', 'label'], _EXAMPLE_FILES)
    def test_anime_classify_score(self, image, label):
        image_file = get_testfile('anime_cls', image)
        scores = anime_classify_score(image_file)
        assert scores[label] > 0.5