Loading imgutils/validate/safe.py +43 −0 Original line number Diff line number Diff line """ Overview: Check if the images are polluted or safe. This is an overall benchmark of all the safe check models: .. image:: safe_benchmark.plot.py.svg :align: center Inspired from `mf666/shit-checker <https://huggingface.co/spaces/mf666/shit-checker>`_. """ import math import random from functools import lru_cache Loading @@ -20,6 +31,14 @@ DEFAULT_MODEL = 'mobilenet.xs.v2' @lru_cache() def _open_model(model_name): """ Open the ONNX model specified by the model name. :param model_name: The name of the model. :type model_name: str :return: The opened ONNX model. :rtype: onnx.ModelProto """ return open_onnx_model(hf_hub_download( repo_id='mf666/shit-checker', filename=f'{model_name}.onnx' Loading Loading @@ -83,6 +102,18 @@ def _pred(image, model_name=DEFAULT_MODEL, max_batch_size=8): def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \ -> Mapping[str, float]: """ Check the safety score of an image. :param image: The image to check. :type image: ImageTyping :param model_name: The name of the safety model. :type model_name: str :param max_batch_size: The maximum batch size for prediction. :type max_batch_size: int :return: A mapping of safety labels and their corresponding scores. :rtype: Mapping[str, float] """ image = load_image(image) _pred_result = _pred(image, model_name, max_batch_size) return dict(zip(['polluted', 'safe'], map(lambda x: x.item(), _pred_result))) Loading @@ -90,6 +121,18 @@ def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_ba def safe_check(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \ -> Tuple[str, float]: """ Check the safety label and score of an image. :param image: The image to check. :type image: ImageTyping :param model_name: The name of the safety model. :type model_name: str :param max_batch_size: The maximum batch size for prediction. :type max_batch_size: int :return: A tuple containing the safety label and score. :rtype: Tuple[str, float] """ image = load_image(image) _pred_result = _pred(image, model_name, max_batch_size) id_ = _pred_result.argmax().item() Loading Loading
imgutils/validate/safe.py +43 −0 Original line number Diff line number Diff line """ Overview: Check if the images are polluted or safe. This is an overall benchmark of all the safe check models: .. image:: safe_benchmark.plot.py.svg :align: center Inspired from `mf666/shit-checker <https://huggingface.co/spaces/mf666/shit-checker>`_. """ import math import random from functools import lru_cache Loading @@ -20,6 +31,14 @@ DEFAULT_MODEL = 'mobilenet.xs.v2' @lru_cache() def _open_model(model_name): """ Open the ONNX model specified by the model name. :param model_name: The name of the model. :type model_name: str :return: The opened ONNX model. :rtype: onnx.ModelProto """ return open_onnx_model(hf_hub_download( repo_id='mf666/shit-checker', filename=f'{model_name}.onnx' Loading Loading @@ -83,6 +102,18 @@ def _pred(image, model_name=DEFAULT_MODEL, max_batch_size=8): def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \ -> Mapping[str, float]: """ Check the safety score of an image. :param image: The image to check. :type image: ImageTyping :param model_name: The name of the safety model. :type model_name: str :param max_batch_size: The maximum batch size for prediction. :type max_batch_size: int :return: A mapping of safety labels and their corresponding scores. :rtype: Mapping[str, float] """ image = load_image(image) _pred_result = _pred(image, model_name, max_batch_size) return dict(zip(['polluted', 'safe'], map(lambda x: x.item(), _pred_result))) Loading @@ -90,6 +121,18 @@ def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_ba def safe_check(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \ -> Tuple[str, float]: """ Check the safety label and score of an image. :param image: The image to check. :type image: ImageTyping :param model_name: The name of the safety model. :type model_name: str :param max_batch_size: The maximum batch size for prediction. :type max_batch_size: int :return: A tuple containing the safety label and score. :rtype: Tuple[str, float] """ image = load_image(image) _pred_result = _pred(image, model_name, max_batch_size) id_ = _pred_result.argmax().item() Loading