Commit effc3784 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): optimize model loader

parent adf5ff04
Loading
Loading
Loading
Loading
+24 −4
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ import json
import os
import threading
from collections import defaultdict
from contextlib import contextmanager
from threading import Lock
from typing import List, Optional, Tuple, Union

@@ -56,6 +57,27 @@ def _check_gradio_env():
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


_MODEL_LOAD_LOCKS = defaultdict(Lock)
_G_ML_LOCK = Lock()


@contextmanager
def _model_load_lock():
    """
    Context manager for thread-safe model loading operations.

    This context manager ensures that model loading operations are thread-safe by using
    process-specific locks. It prevents concurrent model loading operations which could
    lead to race conditions.

    :yields: None
    """
    with _G_ML_LOCK:
        lock = _MODEL_LOAD_LOCKS[os.getpid()]
    with lock:
        yield


def _v_fix(v):
    """
    Round and convert a float value to an integer.
@@ -495,9 +517,7 @@ class YOLOModel:
        self._model_types = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_load_locks = defaultdict(Lock)
        self._model_meta_lock = Lock()
        self._model_exec_locks = defaultdict(Lock)

    def _get_hf_token(self) -> Optional[str]:
        """
@@ -572,7 +592,7 @@ class YOLOModel:
        :rtype: tuple
        """
        cache_key = os.getpid(), threading.get_ident(), model_name
        with self._model_load_locks[os.getpid()]:
        with _model_load_lock():
            if cache_key not in self._models:
                self._check_model_name(model_name)
                model = open_onnx_model(hf_hub_download(
@@ -589,7 +609,7 @@ class YOLOModel:
                    max_infer_size = 640
                names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names'])
                labels = [names_map[i] for i in range(len(names_map))]
                self._models[cache_key] = (model, max_infer_size, labels, self._model_exec_locks[cache_key])
                self._models[cache_key] = (model, max_infer_size, labels, Lock())

        return self._models[cache_key]