Commit e4c98a5e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add generic classify method

parent 8813aa0a
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from .classify import *
+113 −0
Original line number Diff line number Diff line
import json
import os
from functools import lru_cache
from typing import Tuple, Optional, List, Dict

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download, HfFileSystem

from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model

# Names exported by ``from <this module> import *``: the model container class and
# the two function-style prediction helpers defined further down in this module.
__all__ = [
    'ClassifyModel',
    'classify_predict_score',
    'classify_predict',
]


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
    Turn a PIL image into a ``float32`` array suitable as model input.

    The image is resized with bilinear resampling, converted through
    ``rgb_encode`` (requesting ``'CHW'`` ordering) and, unless disabled,
    normalized as ``(data - mean) / std``.

    :param image: The image to encode.
    :type image: Image.Image
    :param size: Target size handed to ``Image.resize``. Default is ``(384, 384)``.
    :type size: Tuple[int, int]
    :param normalize: ``(mean, std)`` pair used for normalization, or ``None`` to skip
        normalization. Default is ``(0.5, 0.5)``.
    :type normalize: Optional[Tuple[float, float]]
    :return: The encoded image data as ``float32``.
    :rtype: np.ndarray
    """
    resized = image.resize(size, Image.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')

    # Nothing to normalize: only the dtype conversion is left to do.
    if normalize is None:
        return encoded.astype(np.float32)

    mean_, std_ = normalize
    # Reshaped to (-1, 1, 1) so the values broadcast along the leading axis of the
    # encoded data (presumably the channel axis, given the 'CHW' order requested).
    mean = np.reshape(np.asarray([mean_]), (-1, 1, 1))
    std = np.reshape(np.asarray([std_]), (-1, 1, 1))
    normalized = (encoded - mean) / std
    return normalized.astype(np.float32)


class ClassifyModel:
    """
    Lazy accessor for the ONNX classification models stored in one Hugging Face repository.

    Every model is looked up in its own sub-directory of the repository, which is expected
    to hold a ``model.onnx`` file and a ``meta.json`` file whose ``labels`` entry lists the
    class names. Models and label lists are downloaded on first use and then cached on
    the instance.

    :param repo_id: Identifier of the Hugging Face repository holding the models.
    :type repo_id: str
    """

    def __init__(self, repo_id: str):
        self.repo_id = repo_id
        # Filled on first access of ``model_names``; ``None`` means "not listed yet".
        self._model_names = None
        # Per-model caches, keyed by model name.
        self._models = {}
        self._labels = {}

    @classmethod
    def _get_hf_token(cls):
        """
        Read the Hugging Face access token from the ``HF_TOKEN`` environment variable.

        :return: The token, or ``None`` when the variable is not set.
        :rtype: Optional[str]
        """
        return os.environ.get('HF_TOKEN')

    @property
    def model_names(self) -> List[str]:
        """
        Names of the models available in the repository.

        A model name is the directory part of every path matching
        ``<repo_id>/*/model.onnx``. The repository is listed only once; the result is
        cached for the lifetime of the instance.

        :return: The list of model names.
        :rtype: List[str]
        """
        if self._model_names is None:
            hf_fs = HfFileSystem(token=self._get_hf_token())
            self._model_names = [
                os.path.dirname(os.path.relpath(item, self.repo_id)) for item in
                hf_fs.glob(f'{self.repo_id}/*/model.onnx')
            ]

        return self._model_names

    def _check_model_name(self, model_name: str):
        """
        Make sure ``model_name`` exists in the repository.

        :param model_name: The model name to validate.
        :type model_name: str
        :raises ValueError: If the name is not one of :attr:`model_names`.
        """
        if model_name not in self.model_names:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

    def _open_model(self, model_name: str):
        """
        Get the ONNX model for ``model_name``, downloading and opening it on first use.

        :param model_name: The model name.
        :type model_name: str
        :return: The object returned by ``open_onnx_model`` for this model.
        :raises ValueError: If the model name is unknown.
        """
        if model_name not in self._models:
            self._check_model_name(model_name)
            self._models[model_name] = open_onnx_model(hf_hub_download(
                self.repo_id,
                f'{model_name}/model.onnx',
                token=self._get_hf_token(),
            ))
        return self._models[model_name]

    def _open_label(self, model_name: str) -> List[str]:
        """
        Get the label list for ``model_name``, downloading ``meta.json`` on first use.

        :param model_name: The model name.
        :type model_name: str
        :return: The value stored under ``labels`` in the model's ``meta.json``.
        :rtype: List[str]
        :raises ValueError: If the model name is unknown.
        """
        if model_name not in self._labels:
            self._check_model_name(model_name)
            meta_file = hf_hub_download(
                self.repo_id,
                f'{model_name}/meta.json',
                token=self._get_hf_token(),
            )
            # JSON is UTF-8 by specification; decode explicitly instead of relying on
            # the platform's locale encoding, which breaks non-ASCII labels on some systems.
            with open(meta_file, 'r', encoding='utf-8') as f:
                self._labels[model_name] = json.load(f)['labels']
        return self._labels[model_name]

    def _raw_predict(self, image: ImageTyping, model_name: str):
        """
        Run the model on one image and return its raw output.

        The image is loaded as RGB with a white background forced, encoded with
        ``_img_encode`` and given a leading axis of length 1 before being fed to the
        model's ``input``; the model's ``output`` is returned as is.

        :param image: The input image.
        :type image: ImageTyping
        :param model_name: The model name.
        :type model_name: str
        :return: The raw ``output`` of the model for the single-image batch.
        :raises ValueError: If the model name is unknown.
        """
        image = load_image(image, force_background='white', mode='RGB')
        input_ = _img_encode(image)[None, ...]
        output, = self._open_model(model_name).run(['output'], {'input': input_})
        return output

    def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]:
        """
        Predict the score of every label for one image.

        :param image: The input image.
        :type image: ImageTyping
        :param model_name: The model name.
        :type model_name: str
        :return: Mapping from label to the model's output value for that label.
        :rtype: Dict[str, float]
        :raises ValueError: If the model name is unknown.
        """
        output = self._raw_predict(image, model_name)
        # output[0] is the row belonging to the only image in the batch.
        values = dict(zip(self._open_label(model_name), map(lambda x: x.item(), output[0])))
        return values

    def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]:
        """
        Predict the best-scoring label for one image.

        :param image: The input image.
        :type image: ImageTyping
        :param model_name: The model name.
        :type model_name: str
        :return: The label with the highest output value, together with that value.
        :rtype: Tuple[str, float]
        :raises ValueError: If the model name is unknown.
        """
        output = self._raw_predict(image, model_name)[0]
        max_id = np.argmax(output)
        return self._open_label(model_name)[max_id], output[max_id].item()

    def clear(self):
        """
        Drop all cached models and label lists.

        The cached list of model names is kept, so the repository is not listed again.
        """
        self._models.clear()
        self._labels.clear()


@lru_cache()
def _open_models_for_repo_id(repo_id: str) -> ClassifyModel:
    """
    Get the :class:`ClassifyModel` bound to the given repository.

    Results are memoized with ``lru_cache``, so calls with the same ``repo_id`` share
    one instance (and its cached models and labels) while the entry stays in the cache.

    :param repo_id: Identifier of the Hugging Face repository holding the models.
    :type repo_id: str
    :return: The model container for this repository.
    :rtype: ClassifyModel
    """
    model_set = ClassifyModel(repo_id)
    return model_set


def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str) -> Dict[str, float]:
    """
    Predict the score of every label for an image, using a model from ``repo_id``.

    :param image: The input image.
    :type image: ImageTyping
    :param repo_id: Identifier of the Hugging Face repository holding the models.
    :type repo_id: str
    :param model_name: Name of the model inside the repository.
    :type model_name: str
    :return: Mapping from label to score.
    :rtype: Dict[str, float]
    """
    classifier = _open_models_for_repo_id(repo_id)
    return classifier.predict_score(image, model_name)


def classify_predict(image: ImageTyping, repo_id: str, model_name: str) -> Tuple[str, float]:
    """
    Predict the best-scoring label for an image, using a model from ``repo_id``.

    :param image: The input image.
    :type image: ImageTyping
    :param repo_id: Identifier of the Hugging Face repository holding the models.
    :type repo_id: str
    :param model_name: Name of the model inside the repository.
    :type model_name: str
    :return: The predicted label and its score.
    :rtype: Tuple[str, float]
    """
    classifier = _open_models_for_repo_id(repo_id)
    return classifier.predict(image, model_name)
+6 −90
Original line number Diff line number Diff line
@@ -15,16 +15,10 @@ Overview:
    The models are hosted on
    `huggingface - deepghs/anime_real_cls <https://huggingface.co/deepghs/anime_real_cls>`_.
"""
import json
from functools import lru_cache
from typing import Tuple, Optional, Dict, List
from typing import Tuple, Dict

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model
from ..data import ImageTyping
from ..generic import classify_predict, classify_predict_score

__all__ = [
    'anime_real_score',
@@ -32,81 +26,7 @@ __all__ = [
]

# Model used by the public functions below when the caller does not pass ``model_name``.
_DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist'


@lru_cache()
def _open_anime_real_model(model_name):
    """
    Download (if needed) and open one anime/real classification model.

    The ``<model_name>/model.onnx`` file is fetched from the
    ``deepghs/anime_real_cls`` repository. Results are memoized per model name.

    :param model_name: The model name.
    :type model_name: str
    :return: The ONNX model.
    """
    onnx_file = hf_hub_download('deepghs/anime_real_cls', f'{model_name}/model.onnx')
    return open_onnx_model(onnx_file)


@lru_cache()
def _get_anime_real_labels(model_name) -> List[str]:
    """
    Get the labels for the anime real model.

    The labels are read from the ``labels`` entry of ``<model_name>/meta.json`` in the
    ``deepghs/anime_real_cls`` repository. Results are memoized per model name.

    :param model_name: The model name.
    :type model_name: str
    :return: The list of labels.
    :rtype: List[str]
    """
    meta_file = hf_hub_download('deepghs/anime_real_cls', f'{model_name}/meta.json')
    # JSON is UTF-8 by specification; decode explicitly rather than relying on the
    # platform's locale encoding.
    with open(meta_file, 'r', encoding='utf-8') as f:
        return json.load(f)['labels']


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
    Convert a PIL image into the ``float32`` array fed to the model.

    The image is resized with bilinear resampling, passed through ``rgb_encode``
    (requesting ``'CHW'`` ordering) and, unless disabled, normalized as
    ``(data - mean) / std``.

    :param image: The input image.
    :type image: Image.Image
    :param size: The desired size of the image. Default is ``(384, 384)``.
    :type size: Tuple[int, int]
    :param normalize: Mean and standard deviation for normalization, or ``None`` to skip
        normalization. Default is ``(0.5, 0.5)``.
    :type normalize: Optional[Tuple[float, float]]
    :return: The encoded image data.
    :rtype: np.ndarray
    """
    resized = image.resize(size, Image.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')

    # Without normalization only the dtype conversion remains.
    if normalize is None:
        return encoded.astype(np.float32)

    mean_, std_ = normalize
    # Shape (-1, 1, 1) lets the values broadcast along the leading axis of the data.
    mean = np.reshape(np.asarray([mean_]), (-1, 1, 1))
    std = np.reshape(np.asarray([std_]), (-1, 1, 1))
    normalized = (encoded - mean) / std
    return normalized.astype(np.float32)


def _raw_anime_real(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME):
    """
    Run the anime/real model on one image and return its raw output.

    The image is loaded as RGB with a white background forced, encoded with
    ``_img_encode`` and given a leading axis of length 1 before inference.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The model name. Default is 'mobilenetv3_v0_dist'.
    :type model_name: str
    :return: The model's ``output`` for the single-image batch.
    :rtype: np.ndarray
    """
    rgb_image = load_image(image, force_background='white', mode='RGB')
    batch = _img_encode(rgb_image)[None, ...]
    model = _open_anime_real_model(model_name)
    # Only the tensor named 'output' is requested, hence the one-element unpacking.
    output, = model.run(['output'], {'input': batch})
    return output
# Hugging Face repository that hosts the anime/real classification models.
_REPO_ID = 'deepghs/anime_real_cls'


def anime_real_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
@@ -156,9 +76,7 @@ def anime_real_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME)
        >>> anime_real_score('real/real/16.jpg')
        {'anime': 1.5513256585109048e-05, 'real': 0.9999845027923584}
    """
    output = _raw_anime_real(image, model_name)
    values = dict(zip(_get_anime_real_labels(model_name), map(lambda x: x.item(), output[0])))
    return values
    return classify_predict_score(image, _REPO_ID, model_name)


def anime_real(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
@@ -208,6 +126,4 @@ def anime_real(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tup
        >>> anime_real('real/real/16.jpg')
        ('real', 0.9999845027923584)
    """
    output = _raw_anime_real(image, model_name)[0]
    max_id = np.argmax(output)
    return _get_anime_real_labels(model_name)[max_id], output[max_id].item()
    return classify_predict(image, _REPO_ID, model_name)