Unverified Commit a82a7f29 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #161 from deepghs/dev/pyolo

dev(narugo): try seperate yolo model sessions by pids and tids
parents f6bbdda1 80f76c5d
Loading
Loading
Loading
Loading
+36 −10
Original line number Diff line number Diff line
@@ -13,11 +13,14 @@ The module supports various image input types and allows customization of confid

import ast
import json
import math
import os
import threading
from collections import defaultdict
from contextlib import contextmanager
from threading import Lock
from typing import List, Optional, Tuple, Union

import math
import numpy as np
import requests
from PIL import Image
@@ -54,6 +57,27 @@ def _check_gradio_env():
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


_MODEL_LOAD_LOCKS = defaultdict(Lock)
_G_ML_LOCK = Lock()


@contextmanager
def _model_load_lock():
    """
    Context manager for thread-safe model loading operations.

    This context manager ensures that model loading operations are thread-safe by using
    process-specific locks. It prevents concurrent model loading operations which could
    lead to race conditions.

    :yields: None
    """
    with _G_ML_LOCK:
        lock = _MODEL_LOAD_LOCKS[os.getpid()]
    with lock:
        yield


def _v_fix(v):
    """
    Round and convert a float value to an integer.
@@ -493,7 +517,7 @@ class YOLOModel:
        self._model_types = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_lock = Lock()
        self._model_meta_lock = Lock()

    def _get_hf_token(self) -> Optional[str]:
        """
@@ -567,8 +591,9 @@ class YOLOModel:
        :return: Tuple containing the ONNX model, maximum inference size, and labels.
        :rtype: tuple
        """
        with self._model_lock:
            if model_name not in self._models:
        cache_key = os.getpid(), threading.get_ident(), model_name
        with _model_load_lock():
            if cache_key not in self._models:
                self._check_model_name(model_name)
                model = open_onnx_model(hf_hub_download(
                    repo_id=self.repo_id,
@@ -584,12 +609,12 @@ class YOLOModel:
                    max_infer_size = 640
                names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names'])
                labels = [names_map[i] for i in range(len(names_map))]
                self._models[model_name] = (model, max_infer_size, labels)
                self._models[cache_key] = (model, max_infer_size, labels, Lock())

        return self._models[model_name]
        return self._models[cache_key]

    def _get_model_type(self, model_name: str):
        with self._model_lock:
        with self._model_meta_lock:
            if model_name not in self._model_types:
                try:
                    model_type_file = hf_hub_download(
@@ -639,10 +664,11 @@ class YOLOModel:
        >>> print(detections[0])  # First detection
        ((100, 200, 300, 400), 'person', 0.95)
        """
        model, max_infer_size, labels = self._open_model(model_name)
        model, max_infer_size, labels, exec_lock = self._open_model(model_name)
        image = load_image(image, mode='RGB')
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic)
        data = rgb_encode(new_image)[None, ...]
        with exec_lock:  # make sure for each session, its execution should be linear
            output, = model.run(['output0'], {'images': data})
        model_type = self._get_model_type(model_name=model_name)
        if model_type == 'yolo':
@@ -732,7 +758,7 @@ class YOLOModel:
                       iou_threshold: float = 0.7, score_threshold: float = 0.25,
                       allow_dynamic: bool = False) \
                -> gr.AnnotatedImage:
            _, _, labels = self._open_model(model_name=model_name)
            _, _, labels, _ = self._open_model(model_name=model_name)
            _colors = list(map(str, rnd_colors(len(labels))))
            _color_map = dict(zip(labels, _colors))
            return gr.AnnotatedImage(
+61 −6
Original line number Diff line number Diff line
@@ -13,14 +13,49 @@ Usage:
    ...     # Some expensive computation
    ...     return x + y
"""

import os
import threading
from collections import defaultdict
from functools import lru_cache, wraps

from typing import Literal

__all__ = ['ts_lru_cache']

LevelTyping = Literal['global', 'process', 'thread']


def _get_context_key(level: LevelTyping = 'global'):
    """
    Get a context key based on the specified caching level.

    :param level: The caching level to use. Can be 'global', 'process', or 'thread'.
    :type level: LevelTyping

    :return: A context key appropriate for the specified level.
    :rtype: tuple or None

    :raises ValueError: If an invalid cache level is specified.

    .. note::
        The function returns:

def ts_lru_cache(**options):
        - None for 'global' level
        - Process ID for 'process' level
        - (Process ID, Thread ID) tuple for 'thread' level
    """
    if level == 'global':
        return None
    elif level == 'process':
        return os.getpid()
    elif level == 'thread':
        return os.getpid(), threading.get_ident()
    else:
        raise ValueError(f'Invalid cache level, '
                         f'\'global\', \'process\' or \'thread\' expected but {level!r} found.')


def ts_lru_cache(level: LevelTyping = 'global', **options):
    """
    A thread-safe version of the lru_cache decorator.

@@ -28,23 +63,37 @@ def ts_lru_cache(**options):
    thread-safety in multithreaded environments. It maintains the same interface
    as the built-in lru_cache, allowing you to specify options like maxsize.

    :param level: The caching level ('global', 'process', or 'thread').
    :type level: LevelTyping
    :param options: Keyword arguments to be passed to the underlying lru_cache.
    :type options: dict

    :return: A thread-safe cached version of the decorated function.
    :rtype: function

    :Example:
        >>> @ts_lru_cache(maxsize=100)
    :example:
        >>> @ts_lru_cache(level='thread', maxsize=100)
        >>> def my_function(x, y):
        ...     # Function implementation
        ...     return x + y

    .. note::
        The decorator provides three levels of caching:

        - global: Single cache shared across all processes and threads
        - process: Separate cache for each process
        - thread: Separate cache for each thread

    .. note::
        While this decorator ensures thread-safety, it may introduce some overhead
        due to lock acquisition. Use it when thread-safety is more critical than
        maximum performance in multithreaded scenarios.

    .. note::
        The decorator preserves the cache_info() and cache_clear() methods
        from the original lru_cache implementation.
    """
    _ = _get_context_key(level)

    def _decorator(func):
        """
@@ -56,19 +105,22 @@ def ts_lru_cache(**options):
        :return: The wrapped function with thread-safe caching.
        :rtype: function
        """

        @lru_cache(**options)
        @wraps(func)
        def _cached_func(*args, **kwargs):
        def _cached_func(*args, __context_key=None, **kwargs):
            """
            Cached version of the original function.

            :param args: Positional arguments to be passed to the original function.
            :param __context_key: Internal context key for cache separation.
            :param kwargs: Keyword arguments to be passed to the original function.

            :return: The result of the original function call.
            """
            return func(*args, **kwargs)

        lock_pool = defaultdict(threading.Lock)
        lock = threading.Lock()

        @wraps(_cached_func)
@@ -84,8 +136,11 @@ def ts_lru_cache(**options):

            :return: The result of the cached function call.
            """
            context_key = _get_context_key(level=level)
            with lock:
                return _cached_func(*args, **kwargs)
                _context_lock = lock_pool[context_key]
            with _context_lock:
                return _cached_func(*args, __context_key=context_key, **kwargs)

        # Preserve cache_info and cache_clear methods if they exist
        if hasattr(_cached_func, 'cache_info'):
+49 −0
Original line number Diff line number Diff line
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch

import pytest

from imgutils.utils import ts_lru_cache
from imgutils.utils.cache import _get_context_key


@pytest.fixture
@@ -124,3 +127,49 @@ class TestTsLruCache:

        assert func(2) == 4  # This should calculate again
        assert call_count == 6


@pytest.fixture
def reset_threading_ident():
    original_get_ident = threading.get_ident
    try:
        yield
    finally:
        threading.get_ident = original_get_ident


@pytest.fixture
def mock_os_getpid():
    with patch('os.getpid', return_value=12345) as mock_getpid:
        yield mock_getpid


@pytest.fixture
def mock_threading_get_ident():
    with patch('threading.get_ident', return_value=67890) as mock_get_ident:
        yield mock_get_ident


@pytest.mark.unittest
class TestGetContextKey:
    def test_global_level(self):
        assert _get_context_key('global') is None

    def test_process_level(self, mock_os_getpid):
        assert _get_context_key('process') == 12345
        mock_os_getpid.assert_called_once()

    def test_thread_level(self, mock_os_getpid, mock_threading_get_ident):
        assert _get_context_key('thread') == (12345, 67890)
        mock_os_getpid.assert_called_once()
        mock_threading_get_ident.assert_called_once()

    def test_invalid_level(self):
        with pytest.raises(ValueError) as excinfo:
            _get_context_key('invalid')
        assert "Invalid cache level" in str(excinfo.value)
        assert "'global', 'process' or 'thread' expected but 'invalid' found" in str(excinfo.value)

    def test_default_level(self):
        # Test that the default level is 'global'
        assert _get_context_key() is None