Loading imgutils/generic/yolo.py +54 −24 Original line number Diff line number Diff line Loading @@ -19,12 +19,15 @@ from threading import Lock from typing import List, Optional, Tuple, Union import numpy as np import requests from PIL import Image from hbutils.color import rnd_colors from hfutils.operate import get_hf_client, get_hf_fs from hbutils.design import SingletonMark from hfutils.operate import get_hf_client from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import HfFileSystem, hf_hub_download from huggingface_hub.errors import OfflineModeIsEnabled, EntryNotFoundError from ..data import load_image, rgb_encode, ImageTyping from ..utils import open_onnx_model, ts_lru_cache Loading Loading @@ -454,6 +457,9 @@ def _safe_eval_names_str(names_str): return result _OFFLINE = SingletonMark('OFFLINE') class YOLOModel: """ A class to manage YOLO models from a Hugging Face repository. Loading Loading @@ -508,6 +514,7 @@ class YOLOModel: """ with self._global_lock: if self._model_names is None: try: hf_fs = HfFileSystem(token=self._get_hf_token()) self._model_names = [ hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id))) Loading @@ -517,6 +524,12 @@ class YOLOModel: filename='*/model.onnx', )) ] except ( requests.exceptions.ConnectionError, requests.exceptions.Timeout, OfflineModeIsEnabled, ): self._model_names = _OFFLINE return self._model_names Loading @@ -528,7 +541,10 @@ class YOLOModel: :type model_name: str :raises ValueError: If the model name is not found in the repository. """ if model_name not in self.model_names: model_list = self.model_names if model_list is _OFFLINE: return # do not check when in offline mode if model_name not in model_list: raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, ' f'models {self.model_names!r} are available.') Loading @@ -545,8 +561,9 @@ class YOLOModel: if model_name not in self._models: self._check_model_name(model_name) model = open_onnx_model(hf_hub_download( self.repo_id, f'{model_name}/model.onnx', repo_id=self.repo_id, repo_type='model', filename=f'{model_name}/model.onnx', token=self._get_hf_token(), )) model_metadata = model.get_modelmeta() Loading @@ -564,17 +581,20 @@ class YOLOModel: def _get_model_type(self, model_name: str): with self._model_lock: if model_name not in self._model_types: hf_fs = get_hf_fs(hf_token=self._get_hf_token()) fs_path = hf_fs_path( try: model_type_file = hf_hub_download( repo_id=self.repo_id, repo_type='model', filename=f'{model_name}/model_type.json', revision='main', token=self._get_hf_token() ) if hf_fs.exists(fs_path): model_type = json.loads(hf_fs.read_text(fs_path))['model_type'] else: except (EntryNotFoundError,): model_type = 'yolo' else: with open(model_type_file, 'r') as f: model_type = json.load(f)['model_type'] self._model_types[model_name] = model_type return self._model_types[model_name] Loading Loading @@ -643,7 +663,9 @@ class YOLOModel: This method removes all cached models and their associated metadata from memory. It's useful for freeing up memory or ensuring that the latest versions of models are loaded. """ self._model_names = None self._models.clear() self._model_types.clear() def make_ui(self, default_model_name: Optional[str] = None, default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7): Loading Loading @@ -671,6 +693,10 @@ class YOLOModel: """ _check_gradio_env() model_list = self.model_names if model_list is _OFFLINE and not default_model_name: raise EnvironmentError('You are in OFFLINE mode, ' 'you must assign a default model name to make this ui usable.') if not default_model_name: hf_client = get_hf_client(hf_token=self._get_hf_token()) selected_model_name, selected_time = None, None Loading Loading @@ -711,7 +737,11 @@ class YOLOModel: with gr.Column(): gr_input_image = gr.Image(type='pil', label='Original Image') with gr.Row(): if model_list is not _OFFLINE: gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') else: gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model', interactive=False) gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size') with gr.Row(): gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold') Loading test/generic/test_yolo.py 0 → 100644 +70 −0 Original line number Diff line number Diff line from unittest.mock import patch import pytest from huggingface_hub import configure_http_backend from huggingface_hub.utils import reset_sessions from imgutils.detect import detection_similarity from imgutils.generic.yolo import _open_models_for_repo_id, yolo_predict from test.testings import get_testfile @pytest.fixture(scope='module', autouse=True) def _release_model_after_run(): try: yield finally: pass @pytest.fixture(scope='function', autouse=True) def _clean_session(): reset_sessions() _open_models_for_repo_id.cache_clear() print('clean session') try: yield finally: reset_sessions() _open_models_for_repo_id.cache_clear() print('clean session') @pytest.mark.unittest class TestGenericYOLO: def test_detect_faces(self): detection = yolo_predict( get_testfile('genshin_post.jpg'), repo_id='deepghs/anime_face_detection', model_name='face_detect_v1.4_s', ) similarity = detection_similarity(detection, [ ((966, 142, 1085, 261), 'face', 0.850458025932312), ((247, 209, 330, 288), 'face', 0.8288277387619019), ((661, 467, 706, 512), 'face', 0.754958987236023), ((481, 282, 522, 325), 'face', 0.7148504257202148) ]) assert similarity >= 0.9 def test_detect_faces_none(self): assert yolo_predict( get_testfile('png_full.png'), repo_id='deepghs/anime_face_detection', model_name='face_detect_v1.4_s', ) == [] @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True) def test_detect_faces_with_offline_mode(self): configure_http_backend() detection = yolo_predict( get_testfile('genshin_post.jpg'), repo_id='deepghs/anime_face_detection', model_name='face_detect_v1.4_s', ) similarity = detection_similarity(detection, [ ((966, 142, 1085, 261), 'face', 0.850458025932312), ((247, 209, 330, 288), 'face', 0.8288277387619019), ((661, 467, 706, 512), 'face', 0.754958987236023), ((481, 282, 522, 325), 'face', 0.7148504257202148) ]) assert similarity >= 0.9 Loading
imgutils/generic/yolo.py +54 −24 Original line number Diff line number Diff line Loading @@ -19,12 +19,15 @@ from threading import Lock from typing import List, Optional, Tuple, Union import numpy as np import requests from PIL import Image from hbutils.color import rnd_colors from hfutils.operate import get_hf_client, get_hf_fs from hbutils.design import SingletonMark from hfutils.operate import get_hf_client from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import HfFileSystem, hf_hub_download from huggingface_hub.errors import OfflineModeIsEnabled, EntryNotFoundError from ..data import load_image, rgb_encode, ImageTyping from ..utils import open_onnx_model, ts_lru_cache Loading Loading @@ -454,6 +457,9 @@ def _safe_eval_names_str(names_str): return result _OFFLINE = SingletonMark('OFFLINE') class YOLOModel: """ A class to manage YOLO models from a Hugging Face repository. Loading Loading @@ -508,6 +514,7 @@ class YOLOModel: """ with self._global_lock: if self._model_names is None: try: hf_fs = HfFileSystem(token=self._get_hf_token()) self._model_names = [ hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id))) Loading @@ -517,6 +524,12 @@ class YOLOModel: filename='*/model.onnx', )) ] except ( requests.exceptions.ConnectionError, requests.exceptions.Timeout, OfflineModeIsEnabled, ): self._model_names = _OFFLINE return self._model_names Loading @@ -528,7 +541,10 @@ class YOLOModel: :type model_name: str :raises ValueError: If the model name is not found in the repository. """ if model_name not in self.model_names: model_list = self.model_names if model_list is _OFFLINE: return # do not check when in offline mode if model_name not in model_list: raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, ' f'models {self.model_names!r} are available.') Loading @@ -545,8 +561,9 @@ class YOLOModel: if model_name not in self._models: self._check_model_name(model_name) model = open_onnx_model(hf_hub_download( self.repo_id, f'{model_name}/model.onnx', repo_id=self.repo_id, repo_type='model', filename=f'{model_name}/model.onnx', token=self._get_hf_token(), )) model_metadata = model.get_modelmeta() Loading @@ -564,17 +581,20 @@ class YOLOModel: def _get_model_type(self, model_name: str): with self._model_lock: if model_name not in self._model_types: hf_fs = get_hf_fs(hf_token=self._get_hf_token()) fs_path = hf_fs_path( try: model_type_file = hf_hub_download( repo_id=self.repo_id, repo_type='model', filename=f'{model_name}/model_type.json', revision='main', token=self._get_hf_token() ) if hf_fs.exists(fs_path): model_type = json.loads(hf_fs.read_text(fs_path))['model_type'] else: except (EntryNotFoundError,): model_type = 'yolo' else: with open(model_type_file, 'r') as f: model_type = json.load(f)['model_type'] self._model_types[model_name] = model_type return self._model_types[model_name] Loading Loading @@ -643,7 +663,9 @@ class YOLOModel: This method removes all cached models and their associated metadata from memory. It's useful for freeing up memory or ensuring that the latest versions of models are loaded. """ self._model_names = None self._models.clear() self._model_types.clear() def make_ui(self, default_model_name: Optional[str] = None, default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7): Loading Loading @@ -671,6 +693,10 @@ class YOLOModel: """ _check_gradio_env() model_list = self.model_names if model_list is _OFFLINE and not default_model_name: raise EnvironmentError('You are in OFFLINE mode, ' 'you must assign a default model name to make this ui usable.') if not default_model_name: hf_client = get_hf_client(hf_token=self._get_hf_token()) selected_model_name, selected_time = None, None Loading Loading @@ -711,7 +737,11 @@ class YOLOModel: with gr.Column(): gr_input_image = gr.Image(type='pil', label='Original Image') with gr.Row(): if model_list is not _OFFLINE: gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') else: gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model', interactive=False) gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size') with gr.Row(): gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold') Loading
test/generic/test_yolo.py 0 → 100644 +70 −0 Original line number Diff line number Diff line from unittest.mock import patch import pytest from huggingface_hub import configure_http_backend from huggingface_hub.utils import reset_sessions from imgutils.detect import detection_similarity from imgutils.generic.yolo import _open_models_for_repo_id, yolo_predict from test.testings import get_testfile @pytest.fixture(scope='module', autouse=True) def _release_model_after_run(): try: yield finally: pass @pytest.fixture(scope='function', autouse=True) def _clean_session(): reset_sessions() _open_models_for_repo_id.cache_clear() print('clean session') try: yield finally: reset_sessions() _open_models_for_repo_id.cache_clear() print('clean session') @pytest.mark.unittest class TestGenericYOLO: def test_detect_faces(self): detection = yolo_predict( get_testfile('genshin_post.jpg'), repo_id='deepghs/anime_face_detection', model_name='face_detect_v1.4_s', ) similarity = detection_similarity(detection, [ ((966, 142, 1085, 261), 'face', 0.850458025932312), ((247, 209, 330, 288), 'face', 0.8288277387619019), ((661, 467, 706, 512), 'face', 0.754958987236023), ((481, 282, 522, 325), 'face', 0.7148504257202148) ]) assert similarity >= 0.9 def test_detect_faces_none(self): assert yolo_predict( get_testfile('png_full.png'), repo_id='deepghs/anime_face_detection', model_name='face_detect_v1.4_s', ) == [] @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True) def test_detect_faces_with_offline_mode(self): configure_http_backend() detection = yolo_predict( get_testfile('genshin_post.jpg'), repo_id='deepghs/anime_face_detection', model_name='face_detect_v1.4_s', ) similarity = detection_similarity(detection, [ ((966, 142, 1085, 261), 'face', 0.850458025932312), ((247, 209, 330, 288), 'face', 0.8288277387619019), ((661, 467, 706, 512), 'face', 0.754958987236023), ((481, 282, 522, 325), 'face', 0.7148504257202148) ]) assert similarity >= 0.9