Commit f54e5ee7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update for booru_yolo

parent c0963685
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ import random
from hfutils.operate import get_hf_fs

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.detect import detect_with_booru_yolo
from imgutils.detect.booru_yolo import detect_with_booru_yolo, _REPO_ID

repository = 'deepghs/booru_yolo'
hf_fs = get_hf_fs()
@@ -20,14 +20,12 @@ class BooruYOLODetectBenchmark(BaseBenchmark):
        self.model_name = model_name

    def load(self):
        from imgutils.detect.booru_yolo import _open_booru_yolo_model, _get_booru_yolo_labels
        _ = _open_booru_yolo_model(self.model_name)
        _ = _get_booru_yolo_labels(self.model_name)
        from imgutils.generic.yolo import _open_models_for_repo_id
        _open_models_for_repo_id(_REPO_ID)._open_model(self.model_name)

    def unload(self):
        from imgutils.detect.booru_yolo import _open_booru_yolo_model, _get_booru_yolo_labels
        _open_booru_yolo_model.cache_clear()
        _get_booru_yolo_labels.cache_clear()
        from imgutils.generic.yolo import _open_models_for_repo_id
        _open_models_for_repo_id.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
+0 −2937

File deleted.

Preview size limit exceeded, changes collapsed.

+11 −94
Original line number Diff line number Diff line
@@ -170,98 +170,17 @@ Overview:
           with outstanding mAP scores, though noted to be not completely fair due to training set biases.

"""
import ast
from functools import lru_cache
from typing import Tuple, List

from huggingface_hub import hf_hub_download

from ._yolo import _data_postprocess, _image_preprocess
from ..data import ImageTyping, load_image, rgb_encode
from ..utils import open_onnx_model


@lru_cache()
def _open_booru_yolo_model(model_name: str):
    """
    Fetch a Booru YOLO model file and open it as an ONNX model.

    The ``{model_name}/model.onnx`` file is requested from the
    ``deepghs/booru_yolo`` model repository via ``hf_hub_download`` and the
    resulting local path is handed to ``open_onnx_model``. Because of the
    ``lru_cache`` decorator, calling this again with the same name returns the
    previously opened model instead of repeating the work.

    :param model_name: Name of the model (the sub-directory inside the repository).
    :type model_name: str
    :return: Opened ONNX model.
    :rtype: onnxruntime.InferenceSession
    """
    onnx_path = hf_hub_download(
        repo_id='deepghs/booru_yolo',
        repo_type='model',
        filename=f'{model_name}/model.onnx',
    )
    return open_onnx_model(onnx_path)


@lru_cache()
def _get_booru_yolo_labels(model_name: str):
    """
    Build the label list of a Booru YOLO model from its embedded metadata.

    The model's custom metadata entry ``'names'`` is parsed with
    :func:`_safe_eval_names_str` into an id-to-name mapping. The returned list
    is indexed by id; ids that do not occur in the mapping keep the
    placeholder ``'<Unknown>'``. The result is cached per model name.

    :param model_name: Name of the model to get labels for.
    :type model_name: str
    :return: List of label names, positioned by their id.
    :rtype: List[str]
    """
    session = _open_booru_yolo_model(model_name)
    raw_names = session.get_modelmeta().custom_metadata_map['names']
    id_to_name = _safe_eval_names_str(raw_names)

    # One slot per id from 0 up to the largest id present in the metadata.
    label_count = max(id_to_name.keys()) + 1
    labels = ['<Unknown>'] * label_count
    for index in id_to_name:
        labels[index] = id_to_name[index]
    return labels


def _safe_eval_names_str(names_str):
    """
    Safely evaluate the names string from model metadata.

    :param names_str: String representation of names dictionary.
    :type names_str: str
    :return: Dictionary of name mappings.
    :rtype: dict
    :raises RuntimeError: If an invalid key or value type is encountered.

    This function parses the names string from the model metadata, ensuring that
    only string and number literals are evaluated for safety.
    """
    node = ast.parse(names_str, mode='eval')
    result = {}
    # noinspection PyUnresolvedReferences
    for key, value in zip(node.body.keys, node.body.values):
        if isinstance(key, (ast.Str, ast.Num)):
            key = ast.literal_eval(key)
        else:
            raise RuntimeError(f"Invalid key type: {key!r}, this should be a bug, "
                               f"please open an issue to dghs-imgutils.")  # pragma: no cover

        if isinstance(value, (ast.Str, ast.Num)):
            value = ast.literal_eval(value)
        else:
            raise RuntimeError(f"Invalid value type: {value!r}, this should be a bug, "
                               f"please open an issue to dghs-imgutils.")  # pragma: no cover

        result[key] = value

    return result

from ..data import ImageTyping
from ..generic import yolo_predict

_DEFAULT_MODEL = 'yolov8s_aa11'
_REPO_ID = 'deepghs/booru_yolo'


def detect_with_booru_yolo(image: ImageTyping, model_name: str = _DEFAULT_MODEL,
                           max_infer_size: int = 640, conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
                           conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Perform object detection on an image using the Booru YOLO model.
@@ -270,8 +189,6 @@ def detect_with_booru_yolo(image: ImageTyping, model_name: str = _DEFAULT_MODEL,
    :type image: ImageTyping
    :param model_name: Name of the Booru YOLO model to use, defaults to 'yolov8s_aa11'.
    :type model_name: str, optional
    :param max_infer_size: Maximum inference size for image preprocessing, defaults to 640.
    :type max_infer_size: int, optional
    :param conf_threshold: Confidence threshold for detection, defaults to 0.25.
    :type conf_threshold: float, optional
    :param iou_threshold: IOU threshold for non-maximum suppression, defaults to 0.7.
@@ -286,10 +203,10 @@ def detect_with_booru_yolo(image: ImageTyping, model_name: str = _DEFAULT_MODEL,
        >>> for box, label, confidence in detections:
        ...     print(f"Detected {label} with confidence {confidence:.2f} at {box}")
    """
    image = load_image(image, mode='RGB')
    model = _open_booru_yolo_model(model_name)
    labels = _get_booru_yolo_labels(model_name)
    new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
    data = rgb_encode(new_image)[None, ...]
    output, = model.run(['output0'], {'images': data})
    return _data_postprocess(output[0], conf_threshold, iou_threshold, old_size, new_size, labels)
    return yolo_predict(
        image=image,
        repo_id=_REPO_ID,
        model_name=model_name,
        conf_threshold=conf_threshold,
        iou_threshold=iou_threshold,
    )
+2 −3
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import pytest
from PIL import Image

from imgutils.detect import detect_with_booru_yolo
from imgutils.detect.booru_yolo import _open_booru_yolo_model, _get_booru_yolo_labels
from imgutils.generic.yolo import _open_models_for_repo_id
from ..testings import get_testfile


@@ -11,8 +11,7 @@ def _release_model_after_run():
    try:
        yield
    finally:
        _open_booru_yolo_model.cache_clear()
        _get_booru_yolo_labels.cache_clear()
        _open_models_for_repo_id.cache_clear()


@pytest.fixture()