Commit b846435c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add docs here

parent 731188ac
Loading
Loading
Loading
Loading
+43 −0
Original line number Diff line number Diff line
"""
Overview:
    Check if the images are polluted or safe.

    This is an overall benchmark of all the safe check models:

    .. image:: safe_benchmark.plot.py.svg
        :align: center

    Inspired from `mf666/shit-checker <https://huggingface.co/spaces/mf666/shit-checker>`_.
"""
import math
import random
from functools import lru_cache
@@ -20,6 +31,14 @@ DEFAULT_MODEL = 'mobilenet.xs.v2'

@lru_cache()
def _open_model(model_name):
    """
    Open the ONNX model specified by the model name.

    :param model_name: The name of the model.
    :type model_name: str
    :return: The opened ONNX model.
    :rtype: onnx.ModelProto
    """
    return open_onnx_model(hf_hub_download(
        repo_id='mf666/shit-checker',
        filename=f'{model_name}.onnx'
@@ -83,6 +102,18 @@ def _pred(image, model_name=DEFAULT_MODEL, max_batch_size=8):

def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Mapping[str, float]:
    """
    Check the safety score of an image.

    :param image: The image to check.
    :type image: ImageTyping
    :param model_name: The name of the safety model.
    :type model_name: str
    :param max_batch_size: The maximum batch size for prediction.
    :type max_batch_size: int
    :return: A mapping of safety labels and their corresponding scores.
    :rtype: Mapping[str, float]
    """
    image = load_image(image)
    _pred_result = _pred(image, model_name, max_batch_size)
    return dict(zip(['polluted', 'safe'], map(lambda x: x.item(), _pred_result)))
@@ -90,6 +121,18 @@ def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_ba

def safe_check(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Tuple[str, float]:
    """
    Check the safety label and score of an image.

    :param image: The image to check.
    :type image: ImageTyping
    :param model_name: The name of the safety model.
    :type model_name: str
    :param max_batch_size: The maximum batch size for prediction.
    :type max_batch_size: int
    :return: A tuple containing the safety label and score.
    :rtype: Tuple[str, float]
    """
    image = load_image(image)
    _pred_result = _pred(image, model_name, max_batch_size)
    id_ = _pred_result.argmax().item()