Unverified Commit 551ccee3 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #114 from deepghs/dev/head

dev(narugo): use the newest head detection model
parents 74d4a2d8 5b4fc1b4
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -4,7 +4,10 @@ from benchmark import BaseBenchmark, create_plot_cli
from imgutils.detect.head import detect_heads, _REPO_ID
from imgutils.generic.yolo import _open_models_for_repo_id

_MODELS = _open_models_for_repo_id(_REPO_ID).model_names
_MODELS = [
    name for name in _open_models_for_repo_id(_REPO_ID).model_names
    if '_v2.0_' in name
]


class HeadDetectBenchmark(BaseBenchmark):
+1135 −782

File changed.

Preview size limit exceeded, changes collapsed.

+18 −18

File changed.

Preview size limit exceeded, changes collapsed.

+23 −9
Original line number Diff line number Diff line
@@ -12,8 +12,8 @@ Overview:
    - Customizable confidence and IoU thresholds
    - Integration with Hugging Face model repository

    The module is based on the `ani_face_detection <https://universe.roboflow.com/linog/ani_face_detection>`_ dataset
    from Roboflow and uses YOLOv8 architecture for object detection.
    The module is based on the `deepghs/anime_head_detection <https://huggingface.co/datasets/deepghs/anime_head_detection>`_
    dataset contributed by our developers and uses YOLOv8/YOLO11 architecture for object detection.

    Example usage and benchmarks are provided in the module overview.

@@ -23,6 +23,7 @@ Overview:
        :align: center

"""
import warnings
from typing import List, Tuple, Optional

from ..data import ImageTyping
@@ -31,8 +32,9 @@ from ..generic import yolo_predict
_REPO_ID = 'deepghs/anime_head_detection'


def detect_heads(image: ImageTyping, level: str = 's', model_name: Optional[str] = None,
                 conf_threshold: float = 0.3, iou_threshold: float = 0.7) \
def detect_heads(image: ImageTyping, level: Optional[str] = None,
                 model_name: Optional[str] = 'head_detect_v2.0_s',
                 conf_threshold: float = 0.4, iou_threshold: float = 0.7) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Detect human heads in anime images using YOLOv8 models.
@@ -43,13 +45,15 @@ def detect_heads(image: ImageTyping, level: str = 's', model_name: Optional[str]
    :param image: The input image for head detection. Can be a file path, URL, or image data.
    :type image: ImageTyping

    :param level: The model level to use. 's' for higher accuracy, 'n' for faster speed. Default is 's'.
    :type level: str
    :param level: The model level to use. 's' for higher accuracy, 'n' for faster speed.
                  Default is None (equivalent to 's'). Only takes effect when ``model_name`` is None or empty.
                  This parameter is deprecated and will be removed in future versions.
    :type level: Optional[str]

    :param model_name: Name of the specific YOLO model to use. If not provided, uses default models based on the level.
    :param model_name: Name of the specific YOLO model to use. If not provided, uses 'head_detect_v2.0_s'.
    :type model_name: Optional[str]

    :param conf_threshold: Confidence threshold for detection results. Only detections with confidence above this value are returned. Default is 0.3.
    :param conf_threshold: Confidence threshold for detection results. Only detections with confidence above this value are returned. Default is 0.4.
    :type conf_threshold: float

    :param iou_threshold: IoU (Intersection over Union) threshold for non-maximum suppression. Helps in removing overlapping detections. Default is 0.7.
@@ -75,11 +79,21 @@ def detect_heads(image: ImageTyping, level: str = 's', model_name: Optional[str]
    .. note::

        For visualization of results, you can use the :func:`imgutils.detect.visual.detection_visualize` function.

    .. warning::

        The 'level' parameter is deprecated and will be removed in future versions. Use 'model_name' instead.
    """
    if level:
        warnings.warn(DeprecationWarning(
            'Argument level in function detect_heads is deprecated and will be removed in the future, '
            'please migrate to model_name as soon as possible.'
        ))

    return yolo_predict(
        image=image,
        repo_id=_REPO_ID,
        model_name=model_name or f'head_detect_v0_{level}',
        model_name=model_name or f'head_detect_v0_{level or "s"}',
        conf_threshold=conf_threshold,
        iou_threshold=iou_threshold,
    )
+13 −16
Original line number Diff line number Diff line
@@ -20,17 +20,15 @@ class TestDetectHead:
        detections = detect_heads(get_testfile('genshin_post.jpg'))
        assert len(detections) == 4

        values = []
        for bbox, label, score in detections:
            assert label == 'head'
            values.append((bbox, int(score * 1000) / 1000))

        assert values == pytest.approx([
            ((202, 156, 356, 293), 0.876),
            ((936, 86, 1134, 267), 0.834),
            ((650, 444, 720, 518), 0.778),
            ((461, 247, 536, 330), 0.434),
        ])
        assert detection_similarity(
            detections,
            [
                ((210, 161, 348, 288), 'head', 0.8935408592224121),
                ((462, 250, 531, 328), 'head', 0.8133165836334229),
                ((651, 439, 725, 514), 'head', 0.8114989995956421),
                ((787, 0, 1124, 262), 'head', 0.780591607093811)
            ]
        ) >= 0.9

    def test_detect_heads_none(self):
        assert detect_heads(get_testfile('png_full.png')) == []
@@ -48,7 +46,6 @@ class TestDetectHead:
            ((202, 156, 356, 293), 'head', 0.876),
            ((936, 86, 1134, 267), 'head', 0.834),
            ((650, 444, 720, 518), 'head', 0.778),
            ((461, 247, 536, 330), 'head', 0.434),
        ])
        assert similarity >= 0.85

@@ -60,7 +57,7 @@ class TestDetectHead:
        #            so this expected result is not meaningful
        #            just make sure the rtdetr models can be properly inferred
        detections = detect_heads(get_testfile('genshin_post.jpg'), model_name=model_name)
        similarity = detection_similarity(detections, [
            ((780, 9, 1125, 208), 'head', 0.3077814280986786)
        ])
        assert similarity >= 0.85
        assert detections == []
        # similarity = detection_similarity(detections, [
        # ])
        # assert similarity >= 0.85