Commit a25fc1e0 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add safe check model

parent e9a2e578
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ imgutils.validate
    portrait
    rating
    real
    safe
    style_age
    teen
    truncate
+21 −0
Original line number Diff line number Diff line
imgutils.validate.safe
=============================================

.. currentmodule:: imgutils.validate.safe

.. automodule:: imgutils.validate.safe


safe_check_score
-----------------------------

.. autofunction:: safe_check_score



safe_check
-----------------------------

.. autofunction:: safe_check

+46 −0
Original line number Diff line number Diff line
import os
import random

from huggingface_hub import HfFileSystem
from natsort import natsorted

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import safe_check

# Filesystem-style client used to list the files published in the model repository.
hf_fs = HfFileSystem()

REPOSITORY = 'mf666/shit-checker'
# One entry per ``*.onnx`` file matched directly under the repository, with the
# repository prefix and the file extension stripped, ordered by ``natsorted``.
MODELS = natsorted(
    os.path.splitext(os.path.relpath(onnx_path, REPOSITORY))[0]
    for onnx_path in hf_fs.glob(f'{REPOSITORY}/*.onnx')
)


class SafeCheckBenchmark(BaseBenchmark):
    """
    Benchmark case that runs :func:`safe_check` with one specific model.
    """

    def __init__(self, model):
        """
        :param model: Model name handed to ``_open_model`` and ``safe_check``.
        """
        super().__init__()
        self.model = model

    def load(self):
        """
        Open the model once; the returned object is discarded.
        """
        from imgutils.validate.safe import _open_model as open_model
        open_model(self.model)

    def unload(self):
        """
        Clear the cache attached to ``_open_model``.
        """
        from imgutils.validate.safe import _open_model as open_model
        open_model.cache_clear()

    def run(self):
        """
        Run one safe check on a randomly chosen entry of ``self.all_images``.
        """
        picked_image = random.choice(self.all_images)
        safe_check(picked_image, self.model)


if __name__ == '__main__':
    # One (label, benchmark) pair per model name, in the order of MODELS.
    benchmarks = [(name, SafeCheckBenchmark(name)) for name in MODELS]
    cli = create_plot_cli(
        benchmarks,
        title='Benchmark for Safe Check Models',
        run_times=10,
        try_times=20,
    )
    cli()
+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from .nsfw import *
from .portrait import *
from .rating import *
from .real import *
from .safe import *
from .style_age import *
from .teen import *
from .truncate import *
+96 −0
Original line number Diff line number Diff line
import math
import random
from functools import lru_cache
from typing import Mapping, Tuple

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

# Names exported by ``from <this module> import *``.
__all__ = [
    'safe_check_score',
    'safe_check',
]

# Model name used whenever a caller does not pass ``model_name`` explicitly.
DEFAULT_MODEL = 'mobilenet.xs.v2'


@lru_cache()
def _open_model(model_name):
    """
    Fetch ``<model_name>.onnx`` from the ``mf666/shit-checker`` repository and open it.

    The result is memoized per ``model_name`` by :func:`functools.lru_cache`,
    so repeated calls with the same name reuse the opened model.

    :param model_name: Base name of the model file, without the ``.onnx`` suffix.
    :return: Whatever ``open_onnx_model`` returns for the downloaded file.
    """
    onnx_path = hf_hub_download(
        repo_id='mf666/shit-checker',
        filename=f'{model_name}.onnx',
    )
    return open_onnx_model(onnx_path)


# Axis order of the arrays produced from RGB images before any transposition.
_DEFAULT_ORDER = 'HWC'


def _get_hwc_map(order_):
    """
    Translate an axis-order string into indices relative to ``'HWC'``.

    :param order_: String made of the letters ``H``, ``W`` and ``C`` (any case).
    :return: Tuple giving, for each letter of ``order_``, its position in ``'HWC'``.
    :raises ValueError: If ``order_`` contains a letter that is not in ``'HWC'``.
    """
    return tuple(map(_DEFAULT_ORDER.index, order_.upper()))


def _encode_channels(image, channels_order='CHW'):
    """
    Convert an image to RGB and return its pixels as a ``float32`` array divided by 255.

    :param image: Object exposing ``convert('RGB')`` whose result ``np.asarray`` can read
        (a PIL image in this module).
    :param channels_order: Axis order of the returned array, written with the letters
        ``H``, ``W`` and ``C`` (case-insensitive).
    :return: ``float32`` array with its axes arranged as ``channels_order``.
    """
    hwc = np.asarray(image.convert('RGB'))
    reordered = np.transpose(hwc, _get_hwc_map(channels_order))
    # ``astype`` fixes the dtype of the result, so no separate dtype assertion is needed.
    return (reordered / 255.0).astype(np.float32)


def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    """
    Resize an image and encode it as a channel-first ``float32`` array.

    :param image: Image to encode; it is resized with bilinear resampling.
    :param size: Target size handed to ``image.resize``.
    :param normalize: ``(mean, std)`` pair applied as ``(data - mean) / std``,
        or ``None`` to skip normalization.
    :return: Encoded ``float32`` array.
    """
    resized = image.resize(size, Image.BILINEAR)
    data = _encode_channels(resized, channels_order='CHW')
    if normalize is None:
        return data.astype(np.float32)

    mean_, std_ = normalize
    # Shape (-1, 1, 1) lets the statistics broadcast over the two trailing axes.
    mean = np.asarray([mean_]).reshape((-1, 1, 1))
    std = np.asarray([std_]).reshape((-1, 1, 1))
    return ((data - mean) / std).astype(np.float32)


def _raw_predict(images, model_name=DEFAULT_MODEL):
    """
    Run the model on a batch of images and average its output over the batch.

    :param images: Iterable of images; each one is converted to RGB and encoded
        with ``_img_encode`` before being stacked into one batch.
    :param model_name: Name of the model to open via ``_open_model``.
    :return: The model's ``'output'`` array averaged along axis 0.
    """
    batch = np.stack([_img_encode(img.convert('RGB')) for img in images])
    session = _open_model(model_name)
    output, = session.run(['output'], {'input': batch})
    return output.mean(axis=0)


# Label names, indexed by position in the averaged prediction vector (see ``safe_check``).
_LABELS = ['polluted', 'safe']


def _pred(image, model_name=DEFAULT_MODEL, max_batch_size=8):
    """
    Score an image by averaging model outputs over randomly placed crops.

    The number of crops grows with the image area (one more than the number of
    384x384 tiles needed to cover that area), capped at ``max_batch_size`` and
    never below one. Crop positions come from :mod:`random`, so repeated calls
    may sample different regions.

    :param image: Image exposing ``width``, ``height`` and ``crop``.
    :param model_name: Name of the model forwarded to ``_raw_predict``.
    :param max_batch_size: Upper bound on the number of crops.
    :return: The averaged scores returned by ``_raw_predict``.
    """
    width, height = image.width, image.height
    wanted = math.ceil((width * height) / (384 * 384)) + 1
    batch_size = int(max(min(wanted, max_batch_size), 1))

    blocks = []
    for _ in range(batch_size):
        # Draw the left edge first, then the top edge; images smaller than the
        # window always start at 0 and the crop is clipped to the image bounds.
        left = random.randint(0, max(0, width - 384))
        top = random.randint(0, max(0, height - 384))
        right = min(left + 384, width)
        bottom = min(top + 384, height)
        blocks.append(image.crop((left, top, right, bottom)))

    return _raw_predict(blocks, model_name)


def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Mapping[str, float]:
    """
    Compute the per-label scores of the safe-check model for an image.

    :param image: Image to check, in any form accepted by ``load_image``.
    :param model_name: Name of the model to use.
    :param max_batch_size: Upper bound on the number of crops scored for this image.
    :return: Mapping with the keys ``'polluted'`` and ``'safe'`` and their scores.
    """
    scores = _pred(load_image(image), model_name, max_batch_size)
    return {
        label: value.item()
        for label, value in zip(['polluted', 'safe'], scores)
    }


def safe_check(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Tuple[str, float]:
    """
    Classify an image with the safe-check model.

    :param image: Image to check, in any form accepted by ``load_image``.
    :param model_name: Name of the model to use.
    :param max_batch_size: Upper bound on the number of crops scored for this image.
    :return: Tuple of the highest-scoring label and that label's score.
    """
    scores = _pred(load_image(image), model_name, max_batch_size)
    best = scores.argmax().item()
    return _LABELS[best], scores[best].item()
Loading