Commit 7a8c4345 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): switch to ts_lru_cache

parent 6a8c3cf6
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
@@ -74,7 +74,6 @@ Overview:
        This module requires onnxruntime version 1.18 or higher.
"""

from functools import lru_cache
from typing import Tuple, List

import numpy as np
@@ -83,7 +82,7 @@ from hbutils.testing.requires.version import VersionInfo
from huggingface_hub import hf_hub_download

from imgutils.data import ImageTyping
from imgutils.utils import open_onnx_model
from imgutils.utils import open_onnx_model, ts_lru_cache
from ..data import load_image


@@ -104,7 +103,7 @@ def _check_compatibility() -> bool:
_REPO_ID = 'deepghs/nudenet_onnx'


@lru_cache()
@ts_lru_cache()
def _open_nudenet_yolo():
    """
    Open and cache the NudeNet YOLO ONNX model.
@@ -118,7 +117,7 @@ def _open_nudenet_yolo():
    ))


@lru_cache()
@ts_lru_cache()
def _open_nudenet_nms():
    """
    Open and cache the NudeNet NMS ONNX model.
+2 −3
Original line number Diff line number Diff line
@@ -23,7 +23,6 @@ Overview:
            :align: center

"""
from functools import lru_cache
from typing import List, Tuple, Optional

import cv2
@@ -33,12 +32,12 @@ from huggingface_hub import hf_hub_download

from ..config.meta import __VERSION__
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model
from ..utils import open_onnx_model, ts_lru_cache

_DEFAULT_MODEL = 'dbnetpp_resnet50_fpnc_1200e_icdar2015'


@lru_cache()
@ts_lru_cache()
def _open_text_detect_model(model: str):
    """
    Get an ONNX session for the specified DBNET or DBNET++ model.
+3 −3
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ Overview:
    Having the **best effect**, closest to the drawing lines,
    but consuming a large amount of memory and computing power at runtime.
"""
from functools import lru_cache, partial
from functools import partial
from typing import Optional

import numpy as np
@@ -14,7 +14,7 @@ from huggingface_hub import hf_hub_download

from ._base import resize_image, cv2_resize, _get_image_edge
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model
from ..utils import open_onnx_model, ts_lru_cache


def _preprocess(input_image: Image.Image, detect_resolution: int = 512):
@@ -23,7 +23,7 @@ def _preprocess(input_image: Image.Image, detect_resolution: int = 512):
    return (input_image / 255.0).transpose(2, 0, 1)[None, ...].astype(np.float32)


@lru_cache()
@ts_lru_cache()
def _open_la_model(coarse: bool):
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
+3 −3
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@
Overview:
    Get edge with lineart anime model.
"""
from functools import lru_cache, partial
from functools import partial
from typing import Optional

import numpy as np
@@ -10,7 +10,7 @@ from huggingface_hub import hf_hub_download

from ._base import resize_image, cv2_resize, _get_image_edge
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model
from ..utils import open_onnx_model, ts_lru_cache


def _preprocess(input_image, detect_resolution: int = 512):
@@ -20,7 +20,7 @@ def _preprocess(input_image, detect_resolution: int = 512):
    return img


@lru_cache()
@ts_lru_cache()
def _open_la_anime_model():
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
+35 −28
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@ It also handles token-based authentication for accessing private Hugging Face re

import json
import os
from functools import lru_cache
from threading import Lock
from typing import Tuple, Optional, List, Dict

import numpy as np
@@ -29,7 +29,7 @@ from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem

from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model
from ..utils import open_onnx_model, ts_lru_cache

try:
    import gradio as gr
@@ -128,6 +128,8 @@ class ClassifyModel:
        self._models = {}
        self._labels = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_lock = Lock()

    def _get_hf_token(self) -> Optional[str]:
        """
@@ -151,6 +153,7 @@ class ClassifyModel:

        :raises RuntimeError: If there's an error accessing the Hugging Face repository.
        """
        with self._global_lock:
            if self._model_names is None:
                hf_fs = HfFileSystem(token=self._get_hf_token())
                self._model_names = [
@@ -191,6 +194,7 @@ class ClassifyModel:

        :raises RuntimeError: If there's an error downloading or opening the model.
        """
        with self._model_lock:
            if model_name not in self._models:
                self._check_model_name(model_name)
                self._models[model_name] = open_onnx_model(hf_hub_download(
@@ -198,6 +202,7 @@ class ClassifyModel:
                    f'{model_name}/model.onnx',
                    token=self._get_hf_token(),
                ))

        return self._models[model_name]

    def _open_label(self, model_name: str) -> List[str]:
@@ -214,6 +219,7 @@ class ClassifyModel:

        :raises RuntimeError: If there's an error downloading or parsing the labels file.
        """
        with self._model_lock:
            if model_name not in self._labels:
                self._check_model_name(model_name)
                with open(hf_hub_download(
@@ -222,6 +228,7 @@ class ClassifyModel:
                        token=self._get_hf_token(),
                ), 'r') as f:
                    self._labels[model_name] = json.load(f)['labels']

        return self._labels[model_name]

    def _raw_predict(self, image: ImageTyping, model_name: str):
@@ -400,7 +407,7 @@ class ClassifyModel:
        )


@lru_cache()
@ts_lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel:
    """
    Open and cache a ClassifyModel instance for the specified repository ID.
Loading