Commit 6747049e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add it for wd14 tagger

parent 756c977f
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ Overview:
    Tagging utils based on wd14 v2, inspired by
    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
"""
from functools import lru_cache
from typing import List, Tuple, Dict

import numpy as np
@@ -15,8 +14,8 @@ from huggingface_hub import hf_hub_download

from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping, has_alpha_channel
from ..utils import open_onnx_model, vreplace
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model, vreplace, ts_lru_cache

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
@@ -58,7 +57,7 @@ def _version_support_check(model_name):
                               f'If you are running on GPU, use "pip install -U onnxruntime-gpu" .')  # pragma: no cover


@lru_cache()
@ts_lru_cache()
def _get_wd14_model(model_name):
    """
    Load an ONNX model from the Hugging Face Hub.
@@ -75,7 +74,7 @@ def _get_wd14_model(model_name):
    ))


@lru_cache()
@ts_lru_cache()
def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]:
    """
    Get labels for the WD14 model.
+7 −1
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ Usage:
"""

import threading
from functools import lru_cache
from functools import lru_cache, wraps

__all__ = ['ts_lru_cache']

@@ -48,15 +48,21 @@ def ts_lru_cache(**options):

    def _decorator(func):
        @lru_cache(**options)
        @wraps(func)
        def _cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        lock = threading.Lock()

        @wraps(_cached_func)
        def _new_func(*args, **kwargs):
            with lock:
                return _cached_func(*args, **kwargs)

        if hasattr(_cached_func, 'cache_info'):
            _new_func.cache_info = _cached_func.cache_info
        if hasattr(_cached_func, 'cache_clear'):
            _new_func.cache_clear = _cached_func.cache_clear
        return _new_func

    return _decorator