Commit 93576502 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): try add that for yolo models

parent 96d10def
Loading
Loading
Loading
Loading
+54 −24
Original line number Diff line number Diff line
@@ -19,12 +19,15 @@ from threading import Lock
from typing import List, Optional, Tuple, Union

import numpy as np
import requests
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client, get_hf_fs
from hbutils.design import SingletonMark
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
from huggingface_hub.errors import OfflineModeIsEnabled, EntryNotFoundError

from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model, ts_lru_cache
@@ -454,6 +457,9 @@ def _safe_eval_names_str(names_str):
    return result


_OFFLINE = SingletonMark('OFFLINE')


class YOLOModel:
    """
    A class to manage YOLO models from a Hugging Face repository.
@@ -508,6 +514,7 @@ class YOLOModel:
        """
        with self._global_lock:
            if self._model_names is None:
                try:
                    hf_fs = HfFileSystem(token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id)))
@@ -517,6 +524,12 @@ class YOLOModel:
                            filename='*/model.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -528,7 +541,10 @@ class YOLOModel:
        :type model_name: str
        :raises ValueError: If the model name is not found in the repository.
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -545,8 +561,9 @@ class YOLOModel:
            if model_name not in self._models:
                self._check_model_name(model_name)
                model = open_onnx_model(hf_hub_download(
                    self.repo_id,
                    f'{model_name}/model.onnx',
                    repo_id=self.repo_id,
                    repo_type='model',
                    filename=f'{model_name}/model.onnx',
                    token=self._get_hf_token(),
                ))
                model_metadata = model.get_modelmeta()
@@ -564,17 +581,20 @@ class YOLOModel:
    def _get_model_type(self, model_name: str):
        with self._model_lock:
            if model_name not in self._model_types:
                hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                fs_path = hf_fs_path(
                try:
                    model_type_file = hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename=f'{model_name}/model_type.json',
                        revision='main',
                        token=self._get_hf_token()
                    )
                if hf_fs.exists(fs_path):
                    model_type = json.loads(hf_fs.read_text(fs_path))['model_type']
                else:
                except (EntryNotFoundError,):
                    model_type = 'yolo'
                else:
                    with open(model_type_file, 'r') as f:
                        model_type = json.load(f)['model_type']

                self._model_types[model_name] = model_type

        return self._model_types[model_name]
@@ -643,7 +663,9 @@ class YOLOModel:
        This method removes all cached models and their associated metadata from memory.
        It's useful for freeing up memory or ensuring that the latest versions of models are loaded.
        """
        self._model_names = None
        self._models.clear()
        self._model_types.clear()

    def make_ui(self, default_model_name: Optional[str] = None,
                default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
@@ -671,6 +693,10 @@ class YOLOModel:
        """
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -711,7 +737,11 @@ class YOLOModel:
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)
                    gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size')
                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
+70 −0
Original line number Diff line number Diff line
from unittest.mock import patch

import pytest
from huggingface_hub import configure_http_backend
from huggingface_hub.utils import reset_sessions

from imgutils.detect import detection_similarity
from imgutils.generic.yolo import _open_models_for_repo_id, yolo_predict
from test.testings import get_testfile


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run():
    try:
        yield
    finally:
        pass


@pytest.fixture(scope='function', autouse=True)
def _clean_session():
    reset_sessions()
    _open_models_for_repo_id.cache_clear()
    print('clean session')
    try:
        yield
    finally:
        reset_sessions()
        _open_models_for_repo_id.cache_clear()
        print('clean session')


@pytest.mark.unittest
class TestGenericYOLO:
    def test_detect_faces(self):
        detection = yolo_predict(
            get_testfile('genshin_post.jpg'),
            repo_id='deepghs/anime_face_detection',
            model_name='face_detect_v1.4_s',
        )
        similarity = detection_similarity(detection, [
            ((966, 142, 1085, 261), 'face', 0.850458025932312),
            ((247, 209, 330, 288), 'face', 0.8288277387619019),
            ((661, 467, 706, 512), 'face', 0.754958987236023),
            ((481, 282, 522, 325), 'face', 0.7148504257202148)
        ])
        assert similarity >= 0.9

    def test_detect_faces_none(self):
        assert yolo_predict(
            get_testfile('png_full.png'),
            repo_id='deepghs/anime_face_detection',
            model_name='face_detect_v1.4_s',
        ) == []

    @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True)
    def test_detect_faces_with_offline_mode(self):
        configure_http_backend()
        detection = yolo_predict(
            get_testfile('genshin_post.jpg'),
            repo_id='deepghs/anime_face_detection',
            model_name='face_detect_v1.4_s',
        )
        similarity = detection_similarity(detection, [
            ((966, 142, 1085, 261), 'face', 0.850458025932312),
            ((247, 209, 330, 288), 'face', 0.8288277387619019),
            ((661, 467, 706, 512), 'face', 0.754958987236023),
            ((481, 282, 522, 325), 'face', 0.7148504257202148)
        ])
        assert similarity >= 0.9