Commit f88644a2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use new model for person

parent 47d39c71
Loading
Loading
Loading
Loading
+8 −6
Original line number Diff line number Diff line
@@ -5,13 +5,14 @@ from imgutils.detect import detect_person


class PersonDetectBenchmark(BaseBenchmark):
    def __init__(self, level):
    def __init__(self, level, plus):
        BaseBenchmark.__init__(self)
        self.level = level
        self.plus = plus

    def load(self):
        from imgutils.detect.person import _open_person_detect_model
        _ = _open_person_detect_model(level=self.level)
        _ = _open_person_detect_model(level=self.level, plus=self.plus)

    def unload(self):
        from imgutils.detect.person import _open_person_detect_model
@@ -19,15 +20,16 @@ class PersonDetectBenchmark(BaseBenchmark):

    def run(self):
        image_file = random.choice(self.all_images)
        _ = detect_person(image_file, level=self.level)
        _ = detect_person(image_file, level=self.level, plus=self.plus)


if __name__ == '__main__':
    create_plot_cli(
        [
            ('person (yolov8s)', PersonDetectBenchmark('s')),
            ('person (yolov8m)', PersonDetectBenchmark('m')),
            ('person (yolov8x)', PersonDetectBenchmark('x')),
            ('person plus (yolov8m)', PersonDetectBenchmark('m', True)),
            ('person (yolov8s)', PersonDetectBenchmark('s', False)),
            ('person (yolov8m)', PersonDetectBenchmark('m', False)),
            ('person (yolov8x)', PersonDetectBenchmark('x', False)),
        ],
        title='Benchmark for Anime Person Detections',
        run_times=10,
+0 −2460

File deleted.

Preview size limit exceeded, changes collapsed.

+5 −2
Original line number Diff line number Diff line
@@ -9,8 +9,11 @@ def _detect(img, **kwargs):

if __name__ == '__main__':
    image_plot(
        (_detect('genshin_post.jpg'), ''),
        (_detect('nian.png'), 'large scale'),
        (_detect('two_bikini_girls.png'), 'closed faces'),
        (_detect('genshin_post.jpg'), 'multiple'),
        (_detect('mostima_post.jpg'), 'anime style'),
        save_as='person_detect.dat.svg',
        columns=1,
        columns=2,
        figsize=(12, 9),
    )
+6 −4
Original line number Diff line number Diff line
@@ -24,14 +24,14 @@ from ..utils import open_onnx_model


@lru_cache()
def _open_person_detect_model(level: str = 's'):
def _open_person_detect_model(level: str = 'm', plus: bool = True):
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
        f'person_detect/person_detect_best_{level}.onnx'
        f'person_detect/person_detect_{"plus_" if plus else ""}best_{level}.onnx'
    ))


def detect_person(image: ImageTyping, level: str = 's', max_infer_size=1216,
def detect_person(image: ImageTyping, level: str = 'm', plus: bool = True, max_infer_size=1216,
                  conf_threshold: float = 0.3, iou_threshold: float = 0.5):
    """
    Overview:
@@ -41,6 +41,8 @@ def detect_person(image: ImageTyping, level: str = 's', max_infer_size=1216,
    :param level: The model level being used can be either `s`, `m` or `x`.
        The `s` model runs faster with smaller system overhead, while the `m` model achieves higher accuracy.
        The default value is `s`.
    :param plus: Use plus model. Default is ``True``. This argument is not recommended to use ``False`` unless
        you know what this means.
    :param max_infer_size: The maximum image size used for model inference, if the image size exceeds this limit,
        the image will be resized and used for inference. The default value is `1216` pixels.
    :param conf_threshold: The confidence threshold, only detection results with confidence scores above
@@ -72,5 +74,5 @@ def detect_person(image: ImageTyping, level: str = 's', max_infer_size=1216,
    new_image, old_size, new_size = _image_preprocess(image, max_infer_size)

    data = rgb_encode(new_image)[None, ...]
    output, = _open_person_detect_model(level).run(['output0'], {'images': data})
    output, = _open_person_detect_model(level, plus).run(['output0'], {'images': data})
    return _data_postprocess(output[0], conf_threshold, iou_threshold, old_size, new_size, ['person'])