Commit 2b746f5c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use new monochrome model

parent 0dfc7e58
Loading
Loading
Loading
Loading
+22 −18
Original line number Diff line number Diff line
@@ -28,18 +28,18 @@ __all__ = [
    'is_monochrome',
]

_MODELS: Mapping[int, str] = {
    0: 'monochrome-caformer-110.onnx',
    2: 'monochrome-caformer_safe2-80.onnx',
    4: 'monochrome-caformer_safe4-70.onnx',
_MODELS: Mapping[Tuple[str, bool], str] = {
    ('caformer_s36', False): 'caformer_s36',
    ('mobilenetv3', False): 'mobilenetv3_large_100',
    ('mobilenetv3', True): 'mobilenetv3_large_100_safe2',
}


@lru_cache()
def _monochrome_validate_model(ckpt):
def _monochrome_validate_model(model: str, safe: bool):
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
        f'monochrome/{ckpt}'
        f'deepghs/monochrome_detect',
        f'{_MODELS[(model, safe)]}/model.onnx',
    ))


@@ -57,16 +57,16 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
    return data


def get_monochrome_score(image: ImageTyping, safe: int = 2) -> float:
def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3', safe: bool = True) -> float:
    """
    Overview:
        Get monochrome score of the given image.

    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param safe: Safe level, with optional values including ``0``, ``2``, and ``4``,
        corresponding to different levels of the model. The default value is 2.
        For more technical details about this model, please refer to:
        https://huggingface.co/deepghs/imgutils-models#monochrome .
    :param model: The model used for inference. The default value is ``mobilenetv3``,
        which offers high runtime performance.
    :param safe: Whether to enable the safe mode. When enabled, calculations will be performed using a model
        with higher precision but lower recall. The default value is ``True``.

    Examples::
        >>> import os
@@ -97,17 +97,19 @@ def get_monochrome_score(image: ImageTyping, safe: int = 2) -> float:
        >>> get_monochrome_score('colored/12.jpg')
        0.025258518755435944
    """
    if safe not in _MODELS:
        raise ValueError(f'Safe level should be one of {set(sorted(_MODELS.keys()))!r}, but {safe!r} found.')
    safe = bool(safe)
    if (model, safe) not in _MODELS:
        raise ValueError(f'Unknown model for monochrome detection - {model!r}, {safe!r}.')

    image = load_image(image, mode='RGB')
    input_data = _2d_encode(image).astype(np.float32)
    input_data = np.stack([input_data])
    output_data, = _monochrome_validate_model(_MODELS[safe]).run(['output'], {'input': input_data})
    return float(output_data[0][1])
    output_data, = _monochrome_validate_model(model, safe).run(['output'], {'input': input_data})
    return output_data[0][0].item()


def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) -> bool:
def is_monochrome(image: ImageTyping, threshold: float = 0.5,
                  model: str = 'mobilenetv3', safe: bool = True) -> bool:
    """
    Overview:
        Predict if the image is monochrome.
@@ -115,6 +117,8 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) ->
    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param threshold: Threshold value during prediction. If the score is higher than the threshold,
        the image will be classified as monochrome.
    :param model: The model used for inference. The default value is ``mobilenetv3``,
        which offers high runtime performance.
    :param safe: Safe level, with optional values including ``0``, ``2``, and ``4``,
        corresponding to different levels of the model. The default value is 2.
        For more technical details about this model, please refer to:
@@ -149,4 +153,4 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) ->
        >>> is_monochrome('colored/12.jpg')
        False
    """
    return get_monochrome_score(image, safe) >= threshold
    return get_monochrome_score(image, model, safe) >= threshold
+24 −24
Original line number Diff line number Diff line
@@ -16,18 +16,19 @@ def _release_model_after_run():

def get_samples():
    return [
        ('monochrome', '143640.jpg'), ('monochrome', '2165075.jpg'), ('monochrome', '2267010.jpg'),
        ('monochrome', '2642558.jpg'), ('monochrome', '3141176.jpg'), ('monochrome', '4530291.jpg'),
        ('monochrome', '4589191.jpg'), ('monochrome', '5182260.jpg'), ('monochrome', '5376761.jpg'),
        ('monochrome', '5608827.jpg'), ('monochrome', '5992785.jpg'), ('monochrome', '6126963.jpg'),
        ('monochrome', '6128358.jpg'), ('monochrome', '6131733.jpg'), ('monochrome', '6154723.jpg'),
        ('monochrome', '843419.jpg'), ('monochrome', '84584446_p3_master1200.jpg'),
        ('monochrome', '87392919_p26_master1200.jpg'),

        ('normal', '2034501.jpg'), ('normal', '2160617.jpg'), ('normal', '3446505.jpg'),
        ('normal', '3725624.jpg'), ('normal', '3899045.jpg'), ('normal', '4278075.jpg'),
        ('normal', '4897680.jpg'), ('normal', '5531563.jpg'), ('normal', '62722650_p14_master1200.jpg'),
        ('normal', '86243980_p0_master1200.jpg'), ('normal', '89270548_p3_master1200.jpg')
        ('monochrome', '6130053.jpg'),
        ('monochrome', '6125854(第 3 个复件).jpg'),
        ('monochrome', '5221834.jpg'),
        ('monochrome', '1951253.jpg'),
        ('monochrome', '4879658.jpg'),
        ('monochrome', '80750471_p3_master1200.jpg'),

        ('normal', '54566940_p0_master1200.jpg'),
        ('normal', '60817155_p18_master1200.jpg'),
        ('normal', '4945494.jpg'),
        ('normal', '4008375.jpg'),
        ('normal', '2416278.jpg'),
        ('normal', '842709.jpg')
    ]


@@ -35,25 +36,24 @@ def get_samples():
class TestValidateMonochrome:
    @pytest.mark.parametrize(*tmatrix({
        ('type_', 'file'): get_samples(),
        'safe': [0, 2, 4],
        ('model', 'safe'): [
            ('caformer_s36', False),
            ('mobilenetv3', False),
            ('mobilenetv3', True),
        ],
    }))
    def test_monochrome_test(self, type_: str, file: str, safe: int):
    def test_monochrome_test(self, type_: str, file: str, model: str, safe: bool):
        filename = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', type_, file)
        if type_ == 'monochrome':
            assert get_monochrome_score(filename, safe=safe) >= 0.5
            assert is_monochrome(filename, safe=safe)
            assert get_monochrome_score(filename, model=model, safe=safe) >= 0.5
            assert is_monochrome(filename, model=model, safe=safe)
        else:
            assert get_monochrome_score(filename, safe=safe) <= 0.5
            assert not is_monochrome(filename, safe=safe)
            assert get_monochrome_score(filename, model=model, safe=safe) <= 0.5
            assert not is_monochrome(filename, model=model, safe=safe)

    def test_monochrome_test_with_unknown_safe(self):
        with pytest.raises(ValueError):
            _ = get_monochrome_score(
                os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'),
                safe=100
            )
        with pytest.raises(ValueError):
            _ = is_monochrome(
                os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'),
                safe=100
                model='Model not found',
            )