Loading docs/source/api_doc/validate/index.rst +1 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,7 @@ imgutils.validate portrait rating real safe style_age teen truncate docs/source/api_doc/validate/safe.rst 0 → 100644 +21 −0 Original line number Diff line number Diff line imgutils.validate.safe ============================================= .. currentmodule:: imgutils.validate.safe .. automodule:: imgutils.validate.safe safe_check_score ----------------------------- .. autofunction:: safe_check_score safe_check ----------------------------- .. autofunction:: safe_check docs/source/api_doc/validate/safe_benchmark.plot.py 0 → 100644 +46 −0 Original line number Diff line number Diff line import os import random from huggingface_hub import HfFileSystem from natsort import natsorted from benchmark import BaseBenchmark, create_plot_cli from imgutils.validate import safe_check hf_fs = HfFileSystem() REPOSITORY = 'mf666/shit-checker' MODELS = natsorted([ os.path.splitext(os.path.relpath(file, REPOSITORY))[0] for file in hf_fs.glob(f'{REPOSITORY}/*.onnx') ]) class SafeCheckBenchmark(BaseBenchmark): def __init__(self, model): BaseBenchmark.__init__(self) self.model = model def load(self): from imgutils.validate.safe import _open_model _ = _open_model(self.model) def unload(self): from imgutils.validate.safe import _open_model _open_model.cache_clear() def run(self): image_file = random.choice(self.all_images) _ = safe_check(image_file, self.model) if __name__ == '__main__': create_plot_cli( [ (name, SafeCheckBenchmark(name)) for name in MODELS ], title='Benchmark for Safe Check Models', run_times=10, try_times=20, )() imgutils/validate/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ from .nsfw import * from .portrait import * from .rating import * from .real import * from .safe import * from .style_age import * from .teen import * from .truncate import * imgutils/validate/safe.py 0 → 100644 +96 −0 Original 
"""
Overview:
    Classify an image as ``polluted`` or ``safe`` with the ONNX models hosted
    in the ``mf666/shit-checker`` repository on Hugging Face.

    The image is not fed to the model as a whole. Instead, several random
    crops of at most 384x384 pixels are taken, each crop is scored, and the
    per-crop scores are averaged.

    .. note::
        The crops are drawn with the global :mod:`random` generator, so two
        calls on the same image may return slightly different scores unless
        the generator is seeded by the caller.
"""
import math
import random
from functools import lru_cache
from typing import Mapping, Tuple

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

__all__ = [
    'safe_check_score',
    'safe_check',
]

DEFAULT_MODEL = 'mobilenet.xs.v2'

# Hugging Face repository holding the ``<model_name>.onnx`` files.
_REPO_ID = 'mf666/shit-checker'
# Axis order of the array obtained from a PIL image.
_DEFAULT_ORDER = 'HWC'
# Labels, in the order of the model's output vector.
_LABELS = ['polluted', 'safe']
# Edge length (in pixels) of the square patches fed to the model.
_PATCH_SIZE = 384


@lru_cache()
def _open_model(model_name):
    """
    Download ``<model_name>.onnx`` from the model repository and open it.

    The result is cached per model name, so every model is opened only once
    per process (until ``_open_model.cache_clear()`` is called).
    """
    return open_onnx_model(hf_hub_download(
        repo_id=_REPO_ID,
        filename=f'{model_name}.onnx',
    ))


def _get_hwc_map(order_):
    """Return the axis permutation that turns an ``HWC`` array into ``order_``."""
    return tuple(_DEFAULT_ORDER.index(c) for c in order_.upper())


def _encode_channels(image, channels_order='CHW'):
    """
    Convert a PIL image into a ``float32`` array scaled to ``[0, 1]``.

    The image is converted to RGB first, then its axes are permuted from
    ``HWC`` to ``channels_order``.
    """
    array = np.asarray(image.convert('RGB'))
    array = np.transpose(array, _get_hwc_map(channels_order))
    return (array / 255.0).astype(np.float32)


def _img_encode(image, size=(_PATCH_SIZE, _PATCH_SIZE), normalize=(0.5, 0.5)):
    """
    Resize ``image`` to ``size`` and encode it as a ``CHW`` ``float32`` array.

    :param image: PIL image to encode.
    :param size: Target size passed to :meth:`PIL.Image.Image.resize`.
    :param normalize: ``(mean, std)`` applied as ``(x - mean) / std``, or
        ``None`` to skip normalization.
    :return: The encoded ``float32`` array.
    """
    image = image.resize(size, Image.BILINEAR)
    data = _encode_channels(image, channels_order='CHW')
    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data.astype(np.float32)


def _raw_predict(images, model_name=DEFAULT_MODEL):
    """
    Run the model on a batch of PIL images and average the outputs.

    :param images: Iterable of PIL images; must not be empty.
    :param model_name: Name of the model to use.
    :return: The model's ``output`` averaged over the batch axis.
    """
    # Convert to RGB *before* encoding so that resizing operates on RGB data.
    items = [_img_encode(image.convert('RGB')) for image in images]
    input_ = np.stack(items)
    output, = _open_model(model_name).run(['output'], {'input': input_})
    return output.mean(axis=0)


def _pred(image, model_name=DEFAULT_MODEL, max_batch_size=8):
    """
    Score ``image`` by averaging the model output over random crops.

    The number of crops is ``ceil(area / 384**2) + 1``, limited to
    ``max_batch_size`` and never below 1. Each crop is at most 384x384;
    on a smaller image the crop covers the image's full extent on that axis.
    """
    width, height = image.width, image.height
    wanted = math.ceil(width * height / (_PATCH_SIZE * _PATCH_SIZE)) + 1
    batch_size = int(max(min(wanted, max_batch_size), 1))

    blocks = []
    for _ in range(batch_size):
        # Draw x before y: keeps the sequence of random calls stable.
        x0 = random.randint(0, max(0, width - _PATCH_SIZE))
        y0 = random.randint(0, max(0, height - _PATCH_SIZE))
        x1 = min(x0 + _PATCH_SIZE, width)
        y1 = min(y0 + _PATCH_SIZE, height)
        blocks.append(image.crop((x0, y0, x1, y1)))

    return _raw_predict(blocks, model_name)


def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Mapping[str, float]:
    """
    Get the score of every label (``polluted`` and ``safe``) for an image.

    :param image: Image to check, in any form accepted by ``load_image``.
    :type image: ImageTyping
    :param model_name: Name of the model to use, default is ``mobilenet.xs.v2``.
    :type model_name: str
    :param max_batch_size: Upper bound of the number of random crops scored
        for this image, default is ``8``. At least one crop is always used.
    :type max_batch_size: int
    :return: A dictionary mapping each label to its averaged score.
    :rtype: Mapping[str, float]
    """
    image = load_image(image)
    scores = _pred(image, model_name, max_batch_size)
    return {label: score.item() for label, score in zip(_LABELS, scores)}


def safe_check(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Tuple[str, float]:
    """
    Get the best-scoring label (``polluted`` or ``safe``) for an image.

    :param image: Image to check, in any form accepted by ``load_image``.
    :type image: ImageTyping
    :param model_name: Name of the model to use, default is ``mobilenet.xs.v2``.
    :type model_name: str
    :param max_batch_size: Upper bound of the number of random crops scored
        for this image, default is ``8``. At least one crop is always used.
    :type max_batch_size: int
    :return: A tuple of the best-scoring label and its averaged score.
    :rtype: Tuple[str, float]
    """
    image = load_image(image)
    scores = _pred(image, model_name, max_batch_size)
    best = scores.argmax().item()
    return _LABELS[best], scores[best].item()
docs/source/api_doc/validate/index.rst +1 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,7 @@ imgutils.validate portrait rating real safe style_age teen truncate
docs/source/api_doc/validate/safe.rst 0 → 100644 +21 −0 Original line number Diff line number Diff line imgutils.validate.safe ============================================= .. currentmodule:: imgutils.validate.safe .. automodule:: imgutils.validate.safe safe_check_score ----------------------------- .. autofunction:: safe_check_score safe_check ----------------------------- .. autofunction:: safe_check
# docs/source/api_doc/validate/safe_benchmark.plot.py
"""Benchmark plot script for the models used by ``imgutils.validate.safe_check``."""
import os
import random

from huggingface_hub import HfFileSystem
from natsort import natsorted

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import safe_check
from imgutils.validate.safe import _open_model

hf_fs = HfFileSystem()

REPOSITORY = 'mf666/shit-checker'

# Every ``*.onnx`` file in the repository is one model; its name is the file
# path relative to the repository, without the extension, in natural order.
MODELS = natsorted(
    os.path.splitext(os.path.relpath(onnx_file, REPOSITORY))[0]
    for onnx_file in hf_fs.glob(f'{REPOSITORY}/*.onnx')
)


class SafeCheckBenchmark(BaseBenchmark):
    """Benchmark of ``safe_check`` for one model."""

    def __init__(self, model):
        """
        :param model: Name of the model to benchmark.
        """
        super().__init__()
        self.model = model

    def load(self):
        """Open the model, which fills the ``_open_model`` cache."""
        _open_model(self.model)

    def unload(self):
        """Drop every cached model."""
        _open_model.cache_clear()

    def run(self):
        """Run ``safe_check`` once on a randomly chosen image."""
        # ``all_images`` is not defined in this script; presumably it is
        # provided by ``BaseBenchmark`` - confirm against the benchmark module.
        chosen = random.choice(self.all_images)
        safe_check(chosen, self.model)


if __name__ == '__main__':
    benchmarks = [(name, SafeCheckBenchmark(name)) for name in MODELS]
    cli = create_plot_cli(
        benchmarks,
        title='Benchmark for Safe Check Models',
        run_times=10,
        try_times=20,
    )
    cli()
imgutils/validate/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ from .nsfw import * from .portrait import * from .rating import * from .real import * from .safe import * from .style_age import * from .teen import * from .truncate import *
# imgutils/validate/safe.py
"""
Overview:
    Classify an image as ``polluted`` or ``safe`` with the ONNX models hosted
    in the ``mf666/shit-checker`` repository on Hugging Face.

    The image is not fed to the model as a whole. Instead, several random
    crops of at most 384x384 pixels are taken, each crop is scored, and the
    per-crop scores are averaged.

    .. note::
        The crops are drawn with the global :mod:`random` generator, so two
        calls on the same image may return slightly different scores unless
        the generator is seeded by the caller.
"""
import math
import random
from functools import lru_cache
from typing import Mapping, Tuple

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

__all__ = [
    'safe_check_score',
    'safe_check',
]

DEFAULT_MODEL = 'mobilenet.xs.v2'

# Hugging Face repository holding the ``<model_name>.onnx`` files.
_REPO_ID = 'mf666/shit-checker'
# Axis order of the array obtained from a PIL image.
_DEFAULT_ORDER = 'HWC'
# Labels, in the order of the model's output vector.
_LABELS = ['polluted', 'safe']
# Edge length (in pixels) of the square patches fed to the model.
_PATCH_SIZE = 384


@lru_cache()
def _open_model(model_name):
    """
    Download ``<model_name>.onnx`` from the model repository and open it.

    The result is cached per model name, so every model is opened only once
    per process (until ``_open_model.cache_clear()`` is called).
    """
    return open_onnx_model(hf_hub_download(
        repo_id=_REPO_ID,
        filename=f'{model_name}.onnx',
    ))


def _get_hwc_map(order_):
    """Return the axis permutation that turns an ``HWC`` array into ``order_``."""
    return tuple(_DEFAULT_ORDER.index(c) for c in order_.upper())


def _encode_channels(image, channels_order='CHW'):
    """
    Convert a PIL image into a ``float32`` array scaled to ``[0, 1]``.

    The image is converted to RGB first, then its axes are permuted from
    ``HWC`` to ``channels_order``.
    """
    array = np.asarray(image.convert('RGB'))
    array = np.transpose(array, _get_hwc_map(channels_order))
    return (array / 255.0).astype(np.float32)


def _img_encode(image, size=(_PATCH_SIZE, _PATCH_SIZE), normalize=(0.5, 0.5)):
    """
    Resize ``image`` to ``size`` and encode it as a ``CHW`` ``float32`` array.

    :param image: PIL image to encode.
    :param size: Target size passed to :meth:`PIL.Image.Image.resize`.
    :param normalize: ``(mean, std)`` applied as ``(x - mean) / std``, or
        ``None`` to skip normalization.
    :return: The encoded ``float32`` array.
    """
    image = image.resize(size, Image.BILINEAR)
    data = _encode_channels(image, channels_order='CHW')
    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data.astype(np.float32)


def _raw_predict(images, model_name=DEFAULT_MODEL):
    """
    Run the model on a batch of PIL images and average the outputs.

    :param images: Iterable of PIL images; must not be empty.
    :param model_name: Name of the model to use.
    :return: The model's ``output`` averaged over the batch axis.
    """
    # Convert to RGB *before* encoding so that resizing operates on RGB data.
    items = [_img_encode(image.convert('RGB')) for image in images]
    input_ = np.stack(items)
    output, = _open_model(model_name).run(['output'], {'input': input_})
    return output.mean(axis=0)


def _pred(image, model_name=DEFAULT_MODEL, max_batch_size=8):
    """
    Score ``image`` by averaging the model output over random crops.

    The number of crops is ``ceil(area / 384**2) + 1``, limited to
    ``max_batch_size`` and never below 1. Each crop is at most 384x384;
    on a smaller image the crop covers the image's full extent on that axis.
    """
    width, height = image.width, image.height
    wanted = math.ceil(width * height / (_PATCH_SIZE * _PATCH_SIZE)) + 1
    batch_size = int(max(min(wanted, max_batch_size), 1))

    blocks = []
    for _ in range(batch_size):
        # Draw x before y: keeps the sequence of random calls stable.
        x0 = random.randint(0, max(0, width - _PATCH_SIZE))
        y0 = random.randint(0, max(0, height - _PATCH_SIZE))
        x1 = min(x0 + _PATCH_SIZE, width)
        y1 = min(y0 + _PATCH_SIZE, height)
        blocks.append(image.crop((x0, y0, x1, y1)))

    return _raw_predict(blocks, model_name)


def safe_check_score(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Mapping[str, float]:
    """
    Get the score of every label (``polluted`` and ``safe``) for an image.

    :param image: Image to check, in any form accepted by ``load_image``.
    :type image: ImageTyping
    :param model_name: Name of the model to use, default is ``mobilenet.xs.v2``.
    :type model_name: str
    :param max_batch_size: Upper bound of the number of random crops scored
        for this image, default is ``8``. At least one crop is always used.
    :type max_batch_size: int
    :return: A dictionary mapping each label to its averaged score.
    :rtype: Mapping[str, float]
    """
    image = load_image(image)
    scores = _pred(image, model_name, max_batch_size)
    return {label: score.item() for label, score in zip(_LABELS, scores)}


def safe_check(image: ImageTyping, model_name: str = DEFAULT_MODEL, max_batch_size: int = 8) \
        -> Tuple[str, float]:
    """
    Get the best-scoring label (``polluted`` or ``safe``) for an image.

    :param image: Image to check, in any form accepted by ``load_image``.
    :type image: ImageTyping
    :param model_name: Name of the model to use, default is ``mobilenet.xs.v2``.
    :type model_name: str
    :param max_batch_size: Upper bound of the number of random crops scored
        for this image, default is ``8``. At least one crop is always used.
    :type max_batch_size: int
    :return: A tuple of the best-scoring label and its averaged score.
    :rtype: Tuple[str, float]
    """
    image = load_image(image)
    scores = _pred(image, model_name, max_batch_size)
    best = scores.argmax().item()
    return _LABELS[best], scores[best].item()