Unverified Commit 3b0b425e authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #22 from deepghs/dev/monochrome

dev(narugo): use new model of monochrome detection
parents 0dfc7e58 4352ea46
Loading
Loading
Loading
Loading
+30.6 KiB (53.9 KiB)
Loading image diff...
+42 −42

File changed.

Preview size limit exceeded, changes collapsed.

+8 −7
Original line number Diff line number Diff line
@@ -2,17 +2,17 @@ import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import get_monochrome_score
from imgutils.validate.monochrome import _MODELS


class MonochromeBenchmark(BaseBenchmark):
    def __init__(self, safe=0):
    def __init__(self, model, safe):
        BaseBenchmark.__init__(self)
        self.model = model
        self.safe = safe

    def load(self):
        from imgutils.validate.monochrome import _monochrome_validate_model
        _ = _monochrome_validate_model(_MODELS[self.safe])
        _ = _monochrome_validate_model(self.model, self.safe)

    def unload(self):
        from imgutils.validate.monochrome import _monochrome_validate_model
@@ -20,15 +20,16 @@ class MonochromeBenchmark(BaseBenchmark):

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_monochrome_score(image_file, safe=self.safe)
        _ = get_monochrome_score(image_file, model=self.model, safe=self.safe)


if __name__ == '__main__':
    create_plot_cli(
        [
            ('monochrome', MonochromeBenchmark()),
            ('monochrome (safe 2)', MonochromeBenchmark(2)),
            ('monochrome (safe 4)', MonochromeBenchmark(4)),
            ('caformer_s36 (unsafe)', MonochromeBenchmark('caformer_s36', False)),
            ('caformer_s36 (safe)', MonochromeBenchmark('caformer_s36', True)),
            ('mobilenetv3 (unsafe)', MonochromeBenchmark('mobilenetv3', False)),
            ('mobilenetv3 (safe)', MonochromeBenchmark('mobilenetv3', True)),
        ],
        title='Benchmark for Monochrome Check Models',
        run_times=10,
+531 −398

File changed.

Preview size limit exceeded, changes collapsed.

+36 −31
Original line number Diff line number Diff line
@@ -28,18 +28,19 @@ __all__ = [
    'is_monochrome',
]

_MODELS: Mapping[int, str] = {
    0: 'monochrome-caformer-110.onnx',
    2: 'monochrome-caformer_safe2-80.onnx',
    4: 'monochrome-caformer_safe4-70.onnx',
_MODELS: Mapping[Tuple[str, bool], str] = {
    ('caformer_s36', False): 'caformer_s36_plus',
    ('caformer_s36', True): 'caformer_s36_plus_safe2',
    ('mobilenetv3', False): 'mobilenetv3_large_100',
    ('mobilenetv3', True): 'mobilenetv3_large_100_safe2',
}


@lru_cache()
def _monochrome_validate_model(ckpt):
def _monochrome_validate_model(model: str, safe: bool):
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
        f'monochrome/{ckpt}'
        f'deepghs/monochrome_detect',
        f'{_MODELS[(model, safe)]}/model.onnx',
    ))


@@ -57,57 +58,59 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
    return data


def get_monochrome_score(image: ImageTyping, safe: int = 2) -> float:
def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3', safe: bool = True) -> float:
    """
    Overview:
        Get monochrome score of the given image.

    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param safe: Safe level, with optional values including ``0``, ``2``, and ``4``,
        corresponding to different levels of the model. The default value is 2.
        For more technical details about this model, please refer to:
        https://huggingface.co/deepghs/imgutils-models#monochrome .
    :param model: The model used for inference. The default value is ``mobilenetv3``,
        which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``.
    :param safe: Whether to enable the safe mode. When enabled, calculations will be performed using a model
        with higher precision but lower recall. The default value is ``True``.

    Examples::
        >>> import os
        >>> from imgutils.validate import get_monochrome_score
        >>>
        >>> get_monochrome_score('mono/1.jpg')  # monochrome images
        0.9789709448814392
        0.9614395499229431
        >>> get_monochrome_score('mono/2.jpg')
        0.973383903503418
        0.9458909034729004
        >>> get_monochrome_score('mono/3.jpg')
        0.9789378046989441
        0.9559807777404785
        >>> get_monochrome_score('mono/4.jpg')
        0.9920350909233093
        0.9651952981948853
        >>> get_monochrome_score('mono/5.jpg')
        0.9865685701370239
        0.9379720687866211
        >>> get_monochrome_score('mono/6.jpg')
        0.9589458703994751
        0.8814834356307983
        >>>
        >>> get_monochrome_score('colored/7.jpg')  # colored images
        0.019315600395202637
        0.03941023349761963
        >>> get_monochrome_score('colored/8.jpg')
        0.008630834519863129
        0.07492382079362869
        >>> get_monochrome_score('colored/9.jpg')
        0.08635691553354263
        0.09546589106321335
        >>> get_monochrome_score('colored/10.jpg')
        0.01357574388384819
        0.016521310433745384
        >>> get_monochrome_score('colored/11.jpg')
        0.00710612116381526
        0.005693843588232994
        >>> get_monochrome_score('colored/12.jpg')
        0.025258518755435944
        0.0315730981528759
    """
    if safe not in _MODELS:
        raise ValueError(f'Safe level should be one of {set(sorted(_MODELS.keys()))!r}, but {safe!r} found.')
    safe = bool(safe)
    if (model, safe) not in _MODELS:
        raise ValueError(f'Unknown model for monochrome detection - {model!r}, {safe!r}.')

    image = load_image(image, mode='RGB')
    input_data = _2d_encode(image).astype(np.float32)
    input_data = np.stack([input_data])
    output_data, = _monochrome_validate_model(_MODELS[safe]).run(['output'], {'input': input_data})
    return float(output_data[0][1])
    output_data, = _monochrome_validate_model(model, safe).run(['output'], {'input': input_data})
    return output_data[0][0].item()


def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) -> bool:
def is_monochrome(image: ImageTyping, threshold: float = 0.5,
                  model: str = 'mobilenetv3', safe: bool = True) -> bool:
    """
    Overview:
        Predict if the image is monochrome.
@@ -115,6 +118,8 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) ->
    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param threshold: Threshold value during prediction. If the score is higher than the threshold,
        the image will be classified as monochrome.
    :param model: The model used for inference. The default value is ``mobilenetv3``,
        which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``.
    :param safe: Whether to enable the safe mode. When enabled, calculations will be performed using a model
        with higher precision but lower recall. The default value is ``True``.
        For more technical details about this model, please refer to:
@@ -149,4 +154,4 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) ->
        >>> is_monochrome('colored/12.jpg')
        False
    """
    return get_monochrome_score(image, safe) >= threshold
    return get_monochrome_score(image, model, safe) >= threshold
Loading