Unverified Commit ab56db74 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #26 from deepghs/dev/monochrome

dev(narugo): use dist model for monochrome detect
parents 2696218e 466890d5
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -30,6 +30,8 @@ if __name__ == '__main__':
            ('caformer_s36 (safe)', MonochromeBenchmark('caformer_s36', True)),
            ('mobilenetv3 (unsafe)', MonochromeBenchmark('mobilenetv3', False)),
            ('mobilenetv3 (safe)', MonochromeBenchmark('mobilenetv3', True)),
            ('mobilenetv3_dist (unsafe)', MonochromeBenchmark('mobilenetv3_dist', False)),
            ('mobilenetv3_dist (safe)', MonochromeBenchmark('mobilenetv3_dist', True)),
        ],
        title='Benchmark for Monochrome Check Models',
        run_times=10,
+578 −350

File changed.

Preview size limit exceeded, changes collapsed.

+6 −4
Original line number Diff line number Diff line
@@ -35,6 +35,8 @@ _MODELS: Mapping[Tuple[str, bool], str] = {
    ('caformer_s36', True): 'caformer_s36_plus_safe2',
    ('mobilenetv3', False): 'mobilenetv3_large_100',
    ('mobilenetv3', True): 'mobilenetv3_large_100_safe2',
    ('mobilenetv3_dist', False): 'mobilenetv3_large_100_dist',
    ('mobilenetv3_dist', True): 'mobilenetv3_large_100_dist_safe2',
}


@@ -60,13 +62,13 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
    return data


def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3', safe: bool = True) -> float:
def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3_dist', safe: bool = True) -> float:
    """
    Overview:
        Get monochrome score of the given image.

    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param model: The model used for inference. The default value is ``mobilenetv3``,
    :param model: The model used for inference. The default value is ``mobilenetv3_dist``,
        which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``.
    :param safe: Whether to enable the safe mode. When enabled, calculations will be performed using a model
        with higher precision but lower recall. The default value is ``True``.
@@ -112,7 +114,7 @@ def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3', safe: b


def is_monochrome(image: ImageTyping, threshold: float = 0.5,
                  model: str = 'mobilenetv3', safe: bool = True) -> bool:
                  model: str = 'mobilenetv3_dist', safe: bool = True) -> bool:
    """
    Overview:
        Predict if the image is monochrome.
@@ -120,7 +122,7 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5,
    :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file.
    :param threshold: Threshold value during prediction. If the score is higher than the threshold,
        the image will be classified as monochrome.
    :param model: The model used for inference. The default value is ``mobilenetv3``,
    :param model: The model used for inference. The default value is ``mobilenetv3_dist``,
        which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``.
    :param safe: Whether to enable the safe mode. When enabled, calculations will be performed using a model
        with higher precision but lower recall. The default value is ``True``.
+2 −0
Original line number Diff line number Diff line
@@ -41,6 +41,8 @@ class TestValidateMonochrome:
            ('caformer_s36', True),
            ('mobilenetv3', False),
            ('mobilenetv3', True),
            ('mobilenetv3_dist', False),
            ('mobilenetv3_dist', True),
        ],
    }))
    def test_monochrome_test(self, type_: str, file: str, model: str, safe: bool):