Commit d4dee18c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update monochrome

parent 27adaabf
Loading
Loading
Loading
Loading
+10 −13
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.validate import get_monochrome_score
from imgutils.validate.monochrome import _REPO_ID

_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names


class MonochromeBenchmark(BaseBenchmark):
    def __init__(self, model, safe):
    def __init__(self, model):
        BaseBenchmark.__init__(self)
        self.model = model
        self.safe = safe

    def load(self):
        from imgutils.validate.monochrome import _monochrome_validate_model
        _ = _monochrome_validate_model(self.model, self.safe)
        _open_models_for_repo_id(_REPO_ID)._open_model(self.model)

    def unload(self):
        from imgutils.validate.monochrome import _monochrome_validate_model
        _monochrome_validate_model.cache_clear()
        _open_models_for_repo_id(_REPO_ID).clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_monochrome_score(image_file, model=self.model, safe=self.safe)
        _ = get_monochrome_score(image_file, model_name=self.model, safe=self.safe)


if __name__ == '__main__':
    create_plot_cli(
        [
            ('caformer_s36 (unsafe)', MonochromeBenchmark('caformer_s36', False)),
            ('caformer_s36 (safe)', MonochromeBenchmark('caformer_s36', True)),
            ('mobilenetv3 (unsafe)', MonochromeBenchmark('mobilenetv3', False)),
            ('mobilenetv3 (safe)', MonochromeBenchmark('mobilenetv3', True)),
            ('mobilenetv3_dist (unsafe)', MonochromeBenchmark('mobilenetv3_dist', False)),
            ('mobilenetv3_dist (safe)', MonochromeBenchmark('mobilenetv3_dist', True)),
            (name, MonochromeBenchmark(name))
            for name in _MODEL_NAMES
        ],
        title='Benchmark for Monochrome Check Models',
        run_times=10,
+0 −2724

File deleted.

Preview size limit exceeded, changes collapsed.

+11 −59
Original line number Diff line number Diff line
@@ -15,63 +15,26 @@ Overview:

    The models are hosted on `huggingface - deepghs/monochrome_detect <https://huggingface.co/deepghs/monochrome_detect>`_.
"""
from functools import lru_cache
from typing import Optional, Tuple, Mapping

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from ..data import ImageTyping, load_image, rgb_encode
from ..utils import open_onnx_model
from ..data import ImageTyping
from ..generic import classify_predict_score, classify_predict

__all__ = [
    'get_monochrome_score',
    'is_monochrome',
]

_MODELS: Mapping[Tuple[str, bool], str] = {
    ('caformer_s36', False): 'caformer_s36_plus',
    ('caformer_s36', True): 'caformer_s36_plus_safe2',
    ('mobilenetv3', False): 'mobilenetv3_large_100',
    ('mobilenetv3', True): 'mobilenetv3_large_100_safe2',
    ('mobilenetv3_dist', False): 'mobilenetv3_large_100_dist',
    ('mobilenetv3_dist', True): 'mobilenetv3_large_100_dist_safe2',
}


@lru_cache()
def _monochrome_validate_model(model: str, safe: bool):
    return open_onnx_model(hf_hub_download(
        f'deepghs/monochrome_detect',
        f'{_MODELS[(model, safe)]}/model.onnx',
    ))

_DEFAULT_MODEL_NAME = 'mobilenetv3_large_100_dist_safe2'
_REPO_ID = 'deepghs/monochrome_detect'

def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
               normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')

    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data


def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3_dist', safe: bool = True) -> float:
def get_monochrome_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> float:
    """
    Overview:
        Get monochrome score of the given image.

    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param model: The model used for inference. The default value is ``mobilenetv3_dist``,
    :param model_name: The model used for inference. The default value is ``mobilenetv3_dist``,
        which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``.
    :param safe: Whether to enable the safe mode. When enabled, calculations will be performed using a model
        with higher precision but lower recall. The default value is ``True``.

    Examples::
        >>> from imgutils.validate import get_monochrome_score
@@ -102,19 +65,11 @@ def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3_dist', sa
        >>> get_monochrome_score('colored/12.jpg')
        0.0315730981528759
    """
    safe = bool(safe)
    if (model, safe) not in _MODELS:
        raise ValueError(f'Unknown model for monochrome detection - {model!r}, {safe!r}.')

    image = load_image(image, mode='RGB')
    input_data = _2d_encode(image).astype(np.float32)
    input_data = np.stack([input_data])
    output_data, = _monochrome_validate_model(model, safe).run(['output'], {'input': input_data})
    return output_data[0][0].item()
    return classify_predict_score(image, _REPO_ID, model_name)['monochrome']


def is_monochrome(image: ImageTyping, threshold: float = 0.5,
                  model: str = 'mobilenetv3_dist', safe: bool = True) -> bool:
                  model_name: str = _DEFAULT_MODEL_NAME) -> bool:
    """
    Overview:
        Predict if the image is monochrome.
@@ -122,12 +77,8 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5,
    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param threshold: Threshold value during prediction. If the score is higher than the threshold,
        the image will be classified as monochrome.
    :param model: The model used for inference. The default value is ``mobilenetv3_dist``,
    :param model_name: The model used for inference. The default value is ``mobilenetv3_dist``,
        which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``.
    :param safe: Safe level, with optional values including ``0``, ``2``, and ``4``,
        corresponding to different levels of the model. The default value is 2.
        For more technical details about this model, please refer to:
        https://huggingface.co/deepghs/imgutils-models#monochrome .

    Examples:
        >>> import os
@@ -158,4 +109,5 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5,
        >>> is_monochrome('colored/12.jpg')
        False
    """
    return get_monochrome_score(image, model, safe) >= threshold
    type_, _ = classify_predict(image, _REPO_ID, model_name)
    return type_ == 'monochrome'
+11 −22
Original line number Diff line number Diff line
@@ -3,7 +3,10 @@ import os.path
import pytest
from hbutils.testing import tmatrix

from imgutils.validate.monochrome import get_monochrome_score, is_monochrome, _monochrome_validate_model
from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.validate.monochrome import get_monochrome_score, is_monochrome, _REPO_ID

_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names


@pytest.fixture(scope='module', autouse=True)
@@ -11,7 +14,7 @@ def _release_model_after_run():
    try:
        yield
    finally:
        _monochrome_validate_model.cache_clear()
        _open_models_for_repo_id(_REPO_ID).clear()


def get_samples():
@@ -36,27 +39,13 @@ def get_samples():
class TestValidateMonochrome:
    @pytest.mark.parametrize(*tmatrix({
        ('type_', 'file'): get_samples(),
        ('model', 'safe'): [
            ('caformer_s36', False),
            ('caformer_s36', True),
            ('mobilenetv3', False),
            ('mobilenetv3', True),
            ('mobilenetv3_dist', False),
            ('mobilenetv3_dist', True),
        ],
        'model_name': _MODEL_NAMES,
    }))
    def test_monochrome_test(self, type_: str, file: str, model: str, safe: bool):
    def test_monochrome_test(self, type_: str, file: str, model_name: str):
        filename = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', type_, file)
        if type_ == 'monochrome':
            assert get_monochrome_score(filename, model=model, safe=safe) >= 0.5
            assert is_monochrome(filename, model=model, safe=safe)
            assert get_monochrome_score(filename, model_name=model_name) >= 0.5
            assert is_monochrome(filename, model_name=model_name)
        else:
            assert get_monochrome_score(filename, model=model, safe=safe) <= 0.5
            assert not is_monochrome(filename, model=model, safe=safe)

    def test_monochrome_test_with_unknown_safe(self):
        with pytest.raises(ValueError):
            _ = get_monochrome_score(
                os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'),
                model='Model not found',
            )
            assert get_monochrome_score(filename, model_name=model_name) <= 0.5
            assert not is_monochrome(filename, model_name=model_name)