Loading imgutils/generic/yolo.py +36 −10 Original line number Diff line number Diff line Loading @@ -13,11 +13,14 @@ The module supports various image input types and allows customization of confid import ast import json import math import os import threading from collections import defaultdict from contextlib import contextmanager from threading import Lock from typing import List, Optional, Tuple, Union import math import numpy as np import requests from PIL import Image Loading Loading @@ -54,6 +57,27 @@ def _check_gradio_env(): f'Please install it with `pip install dghs-imgutils[demo]`.') _MODEL_LOAD_LOCKS = defaultdict(Lock) _G_ML_LOCK = Lock() @contextmanager def _model_load_lock(): """ Context manager for thread-safe model loading operations. This context manager ensures that model loading operations are thread-safe by using process-specific locks. It prevents concurrent model loading operations which could lead to race conditions. :yields: None """ with _G_ML_LOCK: lock = _MODEL_LOAD_LOCKS[os.getpid()] with lock: yield def _v_fix(v): """ Round and convert a float value to an integer. Loading Loading @@ -493,7 +517,7 @@ class YOLOModel: self._model_types = {} self._hf_token = hf_token self._global_lock = Lock() self._model_lock = Lock() self._model_meta_lock = Lock() def _get_hf_token(self) -> Optional[str]: """ Loading Loading @@ -567,8 +591,9 @@ class YOLOModel: :return: Tuple containing the ONNX model, maximum inference size, and labels. :rtype: tuple """ with self._model_lock: if model_name not in self._models: cache_key = os.getpid(), threading.get_ident(), model_name with _model_load_lock(): if cache_key not in self._models: self._check_model_name(model_name) model = open_onnx_model(hf_hub_download( repo_id=self.repo_id, Loading @@ -584,12 +609,12 @@ class YOLOModel: max_infer_size = 640 names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names']) labels = [names_map[i] for i in range(len(names_map))] self._models[model_name] = (model, max_infer_size, labels) self._models[cache_key] = (model, max_infer_size, labels, Lock()) return self._models[model_name] return self._models[cache_key] def _get_model_type(self, model_name: str): with self._model_lock: with self._model_meta_lock: if model_name not in self._model_types: try: model_type_file = hf_hub_download( Loading Loading @@ -639,10 +664,11 @@ class YOLOModel: >>> print(detections[0]) # First detection ((100, 200, 300, 400), 'person', 0.95) """ model, max_infer_size, labels = self._open_model(model_name) model, max_infer_size, labels, exec_lock = self._open_model(model_name) image = load_image(image, mode='RGB') new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic) data = rgb_encode(new_image)[None, ...] with exec_lock: # make sure for each session, its execution should be linear output, = model.run(['output0'], {'images': data}) model_type = self._get_model_type(model_name=model_name) if model_type == 'yolo': Loading Loading @@ -732,7 +758,7 @@ class YOLOModel: iou_threshold: float = 0.7, score_threshold: float = 0.25, allow_dynamic: bool = False) \ -> gr.AnnotatedImage: _, _, labels = self._open_model(model_name=model_name) _, _, labels, _ = self._open_model(model_name=model_name) _colors = list(map(str, rnd_colors(len(labels)))) _color_map = dict(zip(labels, _colors)) return gr.AnnotatedImage( Loading imgutils/utils/cache.py +61 −6 Original line number Diff line number Diff line Loading @@ -13,14 +13,49 @@ Usage: ... # Some expensive computation ... return x + y """ import os import threading from collections import defaultdict from functools import lru_cache, wraps from typing import Literal __all__ = ['ts_lru_cache'] LevelTyping = Literal['global', 'process', 'thread'] def _get_context_key(level: LevelTyping = 'global'): """ Get a context key based on the specified caching level. :param level: The caching level to use. Can be 'global', 'process', or 'thread'. :type level: LevelTyping :return: A context key appropriate for the specified level. :rtype: tuple or None :raises ValueError: If an invalid cache level is specified. .. note:: The function returns: def ts_lru_cache(**options): - None for 'global' level - Process ID for 'process' level - (Process ID, Thread ID) tuple for 'thread' level """ if level == 'global': return None elif level == 'process': return os.getpid() elif level == 'thread': return os.getpid(), threading.get_ident() else: raise ValueError(f'Invalid cache level, ' f'\'global\', \'process\' or \'thread\' expected but {level!r} found.') def ts_lru_cache(level: LevelTyping = 'global', **options): """ A thread-safe version of the lru_cache decorator. Loading @@ -28,23 +63,37 @@ def ts_lru_cache(**options): thread-safety in multithreaded environments. It maintains the same interface as the built-in lru_cache, allowing you to specify options like maxsize. :param level: The caching level ('global', 'process', or 'thread'). :type level: LevelTyping :param options: Keyword arguments to be passed to the underlying lru_cache. :type options: dict :return: A thread-safe cached version of the decorated function. :rtype: function :Example: >>> @ts_lru_cache(maxsize=100) :example: >>> @ts_lru_cache(level='thread', maxsize=100) >>> def my_function(x, y): ... # Function implementation ... return x + y .. note:: The decorator provides three levels of caching: - global: Single cache shared across all processes and threads - process: Separate cache for each process - thread: Separate cache for each thread .. note:: While this decorator ensures thread-safety, it may introduce some overhead due to lock acquisition. Use it when thread-safety is more critical than maximum performance in multithreaded scenarios. .. note:: The decorator preserves the cache_info() and cache_clear() methods from the original lru_cache implementation. """ _ = _get_context_key(level) def _decorator(func): """ Loading @@ -56,19 +105,22 @@ def ts_lru_cache(**options): :return: The wrapped function with thread-safe caching. :rtype: function """ @lru_cache(**options) @wraps(func) def _cached_func(*args, **kwargs): def _cached_func(*args, __context_key=None, **kwargs): """ Cached version of the original function. :param args: Positional arguments to be passed to the original function. :param __context_key: Internal context key for cache separation. :param kwargs: Keyword arguments to be passed to the original function. :return: The result of the original function call. """ return func(*args, **kwargs) lock_pool = defaultdict(threading.Lock) lock = threading.Lock() @wraps(_cached_func) Loading @@ -84,8 +136,11 @@ def ts_lru_cache(**options): :return: The result of the cached function call. """ context_key = _get_context_key(level=level) with lock: return _cached_func(*args, **kwargs) _context_lock = lock_pool[context_key] with _context_lock: return _cached_func(*args, __context_key=context_key, **kwargs) # Preserve cache_info and cache_clear methods if they exist if hasattr(_cached_func, 'cache_info'): Loading test/utils/test_cache.py +49 −0 Original line number Diff line number Diff line import threading import time from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch import pytest from imgutils.utils import ts_lru_cache from imgutils.utils.cache import _get_context_key @pytest.fixture Loading Loading @@ -124,3 +127,49 @@ class TestTsLruCache: assert func(2) == 4 # This should calculate again assert call_count == 6 @pytest.fixture def reset_threading_ident(): original_get_ident = threading.get_ident try: yield finally: threading.get_ident = original_get_ident @pytest.fixture def mock_os_getpid(): with patch('os.getpid', return_value=12345) as mock_getpid: yield mock_getpid @pytest.fixture def mock_threading_get_ident(): with patch('threading.get_ident', return_value=67890) as mock_get_ident: yield mock_get_ident @pytest.mark.unittest class TestGetContextKey: def test_global_level(self): assert _get_context_key('global') is None def test_process_level(self, mock_os_getpid): assert _get_context_key('process') == 12345 mock_os_getpid.assert_called_once() def test_thread_level(self, mock_os_getpid, mock_threading_get_ident): assert _get_context_key('thread') == (12345, 67890) mock_os_getpid.assert_called_once() mock_threading_get_ident.assert_called_once() def test_invalid_level(self): with pytest.raises(ValueError) as excinfo: _get_context_key('invalid') assert "Invalid cache level" in str(excinfo.value) assert "'global', 'process' or 'thread' expected but 'invalid' found" in str(excinfo.value) def test_default_level(self): # Test that the default level is 'global' assert _get_context_key() is None Loading
imgutils/generic/yolo.py +36 −10 Original line number Diff line number Diff line Loading @@ -13,11 +13,14 @@ The module supports various image input types and allows customization of confid import ast import json import math import os import threading from collections import defaultdict from contextlib import contextmanager from threading import Lock from typing import List, Optional, Tuple, Union import math import numpy as np import requests from PIL import Image Loading Loading @@ -54,6 +57,27 @@ def _check_gradio_env(): f'Please install it with `pip install dghs-imgutils[demo]`.') _MODEL_LOAD_LOCKS = defaultdict(Lock) _G_ML_LOCK = Lock() @contextmanager def _model_load_lock(): """ Context manager for thread-safe model loading operations. This context manager ensures that model loading operations are thread-safe by using process-specific locks. It prevents concurrent model loading operations which could lead to race conditions. :yields: None """ with _G_ML_LOCK: lock = _MODEL_LOAD_LOCKS[os.getpid()] with lock: yield def _v_fix(v): """ Round and convert a float value to an integer. Loading Loading @@ -493,7 +517,7 @@ class YOLOModel: self._model_types = {} self._hf_token = hf_token self._global_lock = Lock() self._model_lock = Lock() self._model_meta_lock = Lock() def _get_hf_token(self) -> Optional[str]: """ Loading Loading @@ -567,8 +591,9 @@ class YOLOModel: :return: Tuple containing the ONNX model, maximum inference size, and labels. :rtype: tuple """ with self._model_lock: if model_name not in self._models: cache_key = os.getpid(), threading.get_ident(), model_name with _model_load_lock(): if cache_key not in self._models: self._check_model_name(model_name) model = open_onnx_model(hf_hub_download( repo_id=self.repo_id, Loading @@ -584,12 +609,12 @@ class YOLOModel: max_infer_size = 640 names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names']) labels = [names_map[i] for i in range(len(names_map))] self._models[model_name] = (model, max_infer_size, labels) self._models[cache_key] = (model, max_infer_size, labels, Lock()) return self._models[model_name] return self._models[cache_key] def _get_model_type(self, model_name: str): with self._model_lock: with self._model_meta_lock: if model_name not in self._model_types: try: model_type_file = hf_hub_download( Loading Loading @@ -639,10 +664,11 @@ class YOLOModel: >>> print(detections[0]) # First detection ((100, 200, 300, 400), 'person', 0.95) """ model, max_infer_size, labels = self._open_model(model_name) model, max_infer_size, labels, exec_lock = self._open_model(model_name) image = load_image(image, mode='RGB') new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic) data = rgb_encode(new_image)[None, ...] with exec_lock: # make sure for each session, its execution should be linear output, = model.run(['output0'], {'images': data}) model_type = self._get_model_type(model_name=model_name) if model_type == 'yolo': Loading Loading @@ -732,7 +758,7 @@ class YOLOModel: iou_threshold: float = 0.7, score_threshold: float = 0.25, allow_dynamic: bool = False) \ -> gr.AnnotatedImage: _, _, labels = self._open_model(model_name=model_name) _, _, labels, _ = self._open_model(model_name=model_name) _colors = list(map(str, rnd_colors(len(labels)))) _color_map = dict(zip(labels, _colors)) return gr.AnnotatedImage( Loading
imgutils/utils/cache.py +61 −6 Original line number Diff line number Diff line Loading @@ -13,14 +13,49 @@ Usage: ... # Some expensive computation ... return x + y """ import os import threading from collections import defaultdict from functools import lru_cache, wraps from typing import Literal __all__ = ['ts_lru_cache'] LevelTyping = Literal['global', 'process', 'thread'] def _get_context_key(level: LevelTyping = 'global'): """ Get a context key based on the specified caching level. :param level: The caching level to use. Can be 'global', 'process', or 'thread'. :type level: LevelTyping :return: A context key appropriate for the specified level. :rtype: tuple or None :raises ValueError: If an invalid cache level is specified. .. note:: The function returns: def ts_lru_cache(**options): - None for 'global' level - Process ID for 'process' level - (Process ID, Thread ID) tuple for 'thread' level """ if level == 'global': return None elif level == 'process': return os.getpid() elif level == 'thread': return os.getpid(), threading.get_ident() else: raise ValueError(f'Invalid cache level, ' f'\'global\', \'process\' or \'thread\' expected but {level!r} found.') def ts_lru_cache(level: LevelTyping = 'global', **options): """ A thread-safe version of the lru_cache decorator. Loading @@ -28,23 +63,37 @@ def ts_lru_cache(**options): thread-safety in multithreaded environments. It maintains the same interface as the built-in lru_cache, allowing you to specify options like maxsize. :param level: The caching level ('global', 'process', or 'thread'). :type level: LevelTyping :param options: Keyword arguments to be passed to the underlying lru_cache. :type options: dict :return: A thread-safe cached version of the decorated function. :rtype: function :Example: >>> @ts_lru_cache(maxsize=100) :example: >>> @ts_lru_cache(level='thread', maxsize=100) >>> def my_function(x, y): ... # Function implementation ... return x + y .. note:: The decorator provides three levels of caching: - global: Single cache shared across all processes and threads - process: Separate cache for each process - thread: Separate cache for each thread .. note:: While this decorator ensures thread-safety, it may introduce some overhead due to lock acquisition. Use it when thread-safety is more critical than maximum performance in multithreaded scenarios. .. note:: The decorator preserves the cache_info() and cache_clear() methods from the original lru_cache implementation. """ _ = _get_context_key(level) def _decorator(func): """ Loading @@ -56,19 +105,22 @@ def ts_lru_cache(**options): :return: The wrapped function with thread-safe caching. :rtype: function """ @lru_cache(**options) @wraps(func) def _cached_func(*args, **kwargs): def _cached_func(*args, __context_key=None, **kwargs): """ Cached version of the original function. :param args: Positional arguments to be passed to the original function. :param __context_key: Internal context key for cache separation. :param kwargs: Keyword arguments to be passed to the original function. :return: The result of the original function call. """ return func(*args, **kwargs) lock_pool = defaultdict(threading.Lock) lock = threading.Lock() @wraps(_cached_func) Loading @@ -84,8 +136,11 @@ def ts_lru_cache(**options): :return: The result of the cached function call. """ context_key = _get_context_key(level=level) with lock: return _cached_func(*args, **kwargs) _context_lock = lock_pool[context_key] with _context_lock: return _cached_func(*args, __context_key=context_key, **kwargs) # Preserve cache_info and cache_clear methods if they exist if hasattr(_cached_func, 'cache_info'): Loading
test/utils/test_cache.py +49 −0 Original line number Diff line number Diff line import threading import time from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch import pytest from imgutils.utils import ts_lru_cache from imgutils.utils.cache import _get_context_key @pytest.fixture Loading Loading @@ -124,3 +127,49 @@ class TestTsLruCache: assert func(2) == 4 # This should calculate again assert call_count == 6 @pytest.fixture def reset_threading_ident(): original_get_ident = threading.get_ident try: yield finally: threading.get_ident = original_get_ident @pytest.fixture def mock_os_getpid(): with patch('os.getpid', return_value=12345) as mock_getpid: yield mock_getpid @pytest.fixture def mock_threading_get_ident(): with patch('threading.get_ident', return_value=67890) as mock_get_ident: yield mock_get_ident @pytest.mark.unittest class TestGetContextKey: def test_global_level(self): assert _get_context_key('global') is None def test_process_level(self, mock_os_getpid): assert _get_context_key('process') == 12345 mock_os_getpid.assert_called_once() def test_thread_level(self, mock_os_getpid, mock_threading_get_ident): assert _get_context_key('thread') == (12345, 67890) mock_os_getpid.assert_called_once() mock_threading_get_ident.assert_called_once() def test_invalid_level(self): with pytest.raises(ValueError) as excinfo: _get_context_key('invalid') assert "Invalid cache level" in str(excinfo.value) assert "'global', 'process' or 'thread' expected but 'invalid' found" in str(excinfo.value) def test_default_level(self): # Test that the default level is 'global' assert _get_context_key() is None