Commit a8baab28 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add nudenet detection support

parent 558e96d9
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ from .face import detect_faces
from .halfbody import detect_halfbody
from .hand import detect_hands
from .head import detect_heads
from .nudenet import detect_with_nudenet
from .person import detect_person
from .text import detect_text
from .visual import detection_visualize
+121 −0
Original line number Diff line number Diff line
from functools import lru_cache
from pprint import pprint
from typing import Tuple, List

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from hbutils.testing.requires.version import VersionInfo
from huggingface_hub import hf_hub_download

from imgutils.data import ImageTyping
from imgutils.utils import open_onnx_model
from test.testings import get_testfile
from ..data import load_image


def _check_compatibility() -> bool:
    import onnxruntime
    if VersionInfo(onnxruntime.__version__) < '1.18':
        raise EnvironmentError(f'Nudenet not supported on onnxruntime {onnxruntime.__version__}, '
                               f'please upgrade it to 1.18+ version.\n'
                               f'If you are running on CPU, use "pip install -U onnxruntime" .\n'
                               f'If you are running on GPU, use "pip install -U onnxruntime-gpu" .')  # pragma: no cover


_REPO_ID = 'deepghs/nudenet_onnx'


@lru_cache()
def _open_nudenet_yolo():
    return open_onnx_model(hf_hub_download(
        repo_id=_REPO_ID,
        repo_type='model',
        filename='320n.onnx',
    ))


@lru_cache()
def _open_nudenet_nms():
    return open_onnx_model(hf_hub_download(
        repo_id=_REPO_ID,
        repo_type='model',
        filename='nms-yolov8.onnx',
    ))


def _nn_preprocessing(image: ImageTyping, model_size: int = 320) \
        -> Tuple[np.ndarray, float]:
    image = load_image(image, mode='RGB', force_background='white')
    assert image.mode == 'RGB'
    mat = np.array(image)

    max_size = max(image.width, image.height)

    mat_pad = np.zeros((max_size, max_size, 3), dtype=np.uint8)
    mat_pad[:mat.shape[0], :mat.shape[1], :] = mat
    img_resized = Image.fromarray(mat_pad, mode='RGB').resize((model_size, model_size), resample=Image.BILINEAR)

    input_data = np.array(img_resized).transpose(2, 0, 1).astype(np.float32) / 255.0
    input_data = np.expand_dims(input_data, axis=0)
    return input_data, max_size / model_size


def _make_np_config(topk: int = 100, iou_threshold: float = 0.45, score_threshold: float = 0.25) -> np.ndarray:
    return np.array([topk, iou_threshold, score_threshold]).astype(np.float32)


def _nn_postprocess(selected, global_ratio: float):
    bboxes = []
    num_boxes = selected.shape[0]
    for idx in range(num_boxes):
        data = selected[idx, :]

        scores = data[4:]
        score = np.max(scores)
        label = np.argmax(scores)

        box = data[:4] * global_ratio
        x = (box[0] - 0.5 * box[2]).item()
        y = (box[1] - 0.5 * box[3]).item()
        w = box[2].item()
        h = box[3].item()

        bboxes.append(((x, y, x + w, y + h), _LABELS[label], score.item()))

    return bboxes


_LABELS = [
    "FEMALE_GENITALIA_COVERED",
    "FACE_FEMALE",
    "BUTTOCKS_EXPOSED",
    "FEMALE_BREAST_EXPOSED",
    "FEMALE_GENITALIA_EXPOSED",
    "MALE_BREAST_EXPOSED",
    "ANUS_EXPOSED",
    "FEET_EXPOSED",
    "BELLY_COVERED",
    "FEET_COVERED",
    "ARMPITS_COVERED",
    "ARMPITS_EXPOSED",
    "FACE_MALE",
    "BELLY_EXPOSED",
    "MALE_GENITALIA_EXPOSED",
    "ANUS_COVERED",
    "FEMALE_BREAST_COVERED",
    "BUTTOCKS_COVERED"
]


def detect_with_nudenet(image: ImageTyping, topk: int = 100,
                        iou_threshold: float = 0.45, score_threshold: float = 0.25) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    _check_compatibility()
    input_, global_ratio = _nn_preprocessing(image, model_size=320)
    config = _make_np_config(topk, iou_threshold, score_threshold)
    output0, = _open_nudenet_yolo().run(['output0'], {'images': input_})
    selected, = _open_nudenet_nms().run(['selected'], {'detection': output0, 'config': config})
    return _nn_postprocess(selected[0], global_ratio=global_ratio)

+1 −1
Original line number Diff line number Diff line
@@ -118,7 +118,7 @@ def _get_bounding_box_of_text(image: ImageTyping, model: str, threshold: float)
    return bboxes


@deprecated(deprecated_in="0.2.10", removed_in="0.4", current_version=__VERSION__,
@deprecated(deprecated_in="0.2.10", removed_in="0.5", current_version=__VERSION__,
            details="Use the new function :func:`imgutils.ocr.detect_text_with_ocr` instead")
def detect_text(image: ImageTyping, model: str = _DEFAULT_MODEL, threshold: float = 0.05,
                max_area_size: Optional[int] = 640):
+75 −0
Original line number Diff line number Diff line
import pytest
from PIL import Image

from imgutils.detect import detect_with_nudenet
from imgutils.detect.nudenet import _open_nudenet_nms, _open_nudenet_yolo
from ..testings import get_testfile


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run():
    try:
        yield
    finally:
        _open_nudenet_yolo.cache_clear()
        _open_nudenet_nms.cache_clear()


@pytest.fixture()
def nude_girl_file():
    return get_testfile('nude_girl.png')


@pytest.fixture()
def nude_girl_image(nude_girl_file):
    return Image.open(nude_girl_file)


@pytest.fixture()
def nude_girl_detection():
    return [
        ((321.3878631591797, 242.3542022705078, 429.8410186767578, 345.7248992919922),
         'FEMALE_BREAST_EXPOSED',
         0.832775890827179),
        ((207.8404312133789, 243.68451690673828, 307.2947006225586, 336.3175582885742),
         'FEMALE_BREAST_EXPOSED',
         0.8057667016983032),
        ((203.23711395263672,
          348.42012786865234,
          351.32117462158203,
          511.34781646728516),
         'BELLY_EXPOSED',
         0.7703637480735779),
        ((280.81117248535156,
          678.6565170288086,
          436.11827087402344,
          767.8816909790039),
         'FEET_EXPOSED',
         0.747696578502655),
        ((185.25140380859375, 518.0437889099121, 252.96240234375, 625.8465919494629),
         'FEMALE_GENITALIA_EXPOSED',
         0.7381105422973633),
        ((287.9706840515137, 124.07051467895508, 392.7693061828613, 225.3848991394043),
         'FACE_FEMALE',
         0.6556487083435059),
        ((103.20288848876953,
          564.7838439941406,
          352.05843353271484,
          707.6390075683594),
         'BUTTOCKS_EXPOSED',
         0.44306617975234985),
        ((396.1982898712158, 224.24786376953125, 450.53956413269043, 290.279541015625),
         'ARMPITS_EXPOSED',
         0.31386712193489075)
    ]


@pytest.mark.unittest
class TestDetectNudeNet:
    def test_detect_with_nudenet_file(self, nude_girl_file, nude_girl_detection):
        detection = detect_with_nudenet(nude_girl_file)
        assert detection == pytest.approx(nude_girl_detection)

    def test_detect_with_nudenet_image(self, nude_girl_image, nude_girl_detection):
        detection = detect_with_nudenet(nude_girl_image)
        assert detection == pytest.approx(nude_girl_detection)