Commit 593d20c7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use the newest head models

parent 5f14d6af
Loading
Loading
Loading
Loading
+12 −4
Original line number Diff line number Diff line
@@ -12,8 +12,8 @@ Overview:
    - Customizable confidence and IoU thresholds
    - Integration with Hugging Face model repository

    The module is based on the `ani_face_detection <https://universe.roboflow.com/linog/ani_face_detection>`_ dataset
    from Roboflow and uses YOLOv8 architecture for object detection.
    The module is based on the `deepghs/anime_head_detection <https://huggingface.co/datasets/deepghs/anime_head_detection>`_
    dataset contributed by our developers and uses YOLOv8/YOLO11 architecture for object detection.

    Example usage and benchmarks are provided in the module overview.

@@ -23,6 +23,7 @@ Overview:
        :align: center

"""
import warnings
from typing import List, Tuple, Optional

from ..data import ImageTyping
@@ -31,8 +32,9 @@ from ..generic import yolo_predict
_REPO_ID = 'deepghs/anime_head_detection'


def detect_heads(image: ImageTyping, level: str = 's', model_name: Optional[str] = None,
                 conf_threshold: float = 0.3, iou_threshold: float = 0.7) \
def detect_heads(image: ImageTyping, level: Optional[str] = None,
                 model_name: Optional[str] = 'head_detect_v2.0_s',
                 conf_threshold: float = 0.4, iou_threshold: float = 0.7) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Detect human heads in anime images using YOLOv8 models.
@@ -76,6 +78,12 @@ def detect_heads(image: ImageTyping, level: str = 's', model_name: Optional[str]

        For visualization of results, you can use the :func:`imgutils.detect.visual.detection_visualize` function.
    """
    if level:
        warnings.warn(DeprecationWarning(
            'Argument level in function detect_heads is deprecated and will be removed in the future, '
            'please migrate to model_name as soon as possible.'
        ))

    return yolo_predict(
        image=image,
        repo_id=_REPO_ID,
+13 −16
Original line number Diff line number Diff line
@@ -20,17 +20,15 @@ class TestDetectHead:
        detections = detect_heads(get_testfile('genshin_post.jpg'))
        assert len(detections) == 4

        values = []
        for bbox, label, score in detections:
            assert label == 'head'
            values.append((bbox, int(score * 1000) / 1000))

        assert values == pytest.approx([
            ((202, 156, 356, 293), 0.876),
            ((936, 86, 1134, 267), 0.834),
            ((650, 444, 720, 518), 0.778),
            ((461, 247, 536, 330), 0.434),
        ])
        assert detection_similarity(
            detections,
            [
                ((210, 161, 348, 288), 'head', 0.8935408592224121),
                ((462, 250, 531, 328), 'head', 0.8133165836334229),
                ((651, 439, 725, 514), 'head', 0.8114989995956421),
                ((787, 0, 1124, 262), 'head', 0.780591607093811)
            ]
        ) >= 0.9

    def test_detect_heads_none(self):
        assert detect_heads(get_testfile('png_full.png')) == []
@@ -48,7 +46,6 @@ class TestDetectHead:
            ((202, 156, 356, 293), 'head', 0.876),
            ((936, 86, 1134, 267), 'head', 0.834),
            ((650, 444, 720, 518), 'head', 0.778),
            ((461, 247, 536, 330), 'head', 0.434),
        ])
        assert similarity >= 0.85

@@ -60,7 +57,7 @@ class TestDetectHead:
        #            so this expected result is 100% bullshit
        #            just make sure the rtdetr models can be properly inferred
        detections = detect_heads(get_testfile('genshin_post.jpg'), model_name=model_name)
        similarity = detection_similarity(detections, [
            ((780, 9, 1125, 208), 'head', 0.3077814280986786)
        ])
        assert similarity >= 0.85
        assert detections == []
        # similarity = detection_similarity(detections, [
        # ])
        # assert similarity >= 0.85