Loading docs/source/api_doc/tagging/wd14.rst +8 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,11 @@ get_wd14_tags .. autofunction:: get_wd14_tags convert_wd14_emb_to_prediction ---------------------------------------------- .. autofunction:: convert_wd14_emb_to_prediction imgutils/tagging/wd14.py +115 −19 Original line number Diff line number Diff line """ Overview: Tagging utils based on wd14 v2, inspired by `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . This module provides utilities for image tagging using WD14 taggers. It includes functions for loading models, processing images, and extracting tags. The module is inspired by the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. """ from functools import lru_cache from typing import List, Tuple Loading Loading @@ -51,6 +56,13 @@ _DEFAULT_MODEL_NAME = 'SwinV2_v3' def _version_support_check(model_name): """ Check if the current onnxruntime version supports the given model. :param model_name: The name of the model to check. :type model_name: str :raises EnvironmentError: If the model is not supported by the current onnxruntime version. """ if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.\n' Loading @@ -63,7 +75,7 @@ def _get_wd14_model(model_name): """ Load an ONNX model from the Hugging Face Hub. :param model_name: The name of the model. :param model_name: The name of the model to load. :type model_name: str :return: The loaded ONNX model. :rtype: ONNXModel Loading @@ -77,6 +89,14 @@ def _get_wd14_model(model_name): @lru_cache() def _get_wd14_weights(model_name): """ Load the weights for a WD14 model. :param model_name: The name of the model. :type model_name: str :return: The loaded weights. 
:rtype: numpy.ndarray """ _version_support_check(model_name) return np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', Loading Loading @@ -111,10 +131,17 @@ def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Compute the Maximum Cut Thresholding (MCut) for multi-label classification. This method is based on the paper: Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). :param probs: Array of probabilities. :type probs: numpy.ndarray :return: The computed threshold. :rtype: float """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] Loading @@ -124,6 +151,16 @@ def _mcut_threshold(probs) -> float: def _prepare_image_for_tagging(image: ImageTyping, target_size: int): """ Prepare an image for tagging by resizing and padding it. :param image: The input image. :type image: ImageTyping :param target_size: The target size for the image. :type target_size: int :return: The prepared image as a numpy array. :rtype: numpy.ndarray """ image = load_image(image, force_background=None, mode=None) image_shape = image.size max_dim = max(image_shape) Loading Loading @@ -155,6 +192,30 @@ def _postprocess_embedding( drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Post-process the embedding and prediction results. :param pred: The prediction array. :type pred: numpy.ndarray :param embedding: The embedding array. :type embedding: numpy.ndarray :param model_name: The name of the model used. :type model_name: str :param general_threshold: Threshold for general tags. :type general_threshold: float :param general_mcut_enabled: Whether to use MCut for general tags. :type general_mcut_enabled: bool :param character_threshold: Threshold for character tags. 
:type character_threshold: float :param character_mcut_enabled: Whether to use MCut for character tags. :type character_mcut_enabled: bool :param no_underline: Whether to remove underscores from tag names. :type no_underline: bool :param drop_overlap: Whether to drop overlapping tags. :type drop_overlap: bool :param fmt: The format of the output. :return: The post-processed results. """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) labels = list(zip(tag_names, pred.astype(float))) Loading Loading @@ -202,9 +263,10 @@ def get_wd14_tags( fmt=('rating', 'general', 'character'), ): """ Overview: Get tags for an image with wd14 taggers. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . Get tags for an image using WD14 taggers. This function is similar to the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. :param image: The input image. :type image: ImageTyping Loading @@ -224,19 +286,17 @@ def get_wd14_tags( :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. :type fmt: Any :return: Prediction result based on the provided ``fmt``. In the default case, it should be a tuple of ``rating``, ``general`` and ``character``. :return: Prediction result based on the provided fmt. .. 
note:: About ``fmt`` argument, these are the available names: The fmt argument can include the following keys: * ``rating``, a dict containing ratings and their confidences * ``general``, a dict containing general tags and their confidences * ``character``, a dict containing character tags and their confidences * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization * ``prediction``, a 1-dim prediction result of image - ``rating``: a dict containing ratings and their confidences - ``general``: a dict containing general tags and their confidences - ``character``: a dict containing character tags and their confidences - ``tag``: a dict containing all tags (including general and character, not including rating) and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``prediction``: a 1-dim prediction result of image Example: Here are some images for example Loading Loading @@ -299,6 +359,42 @@ def convert_wd14_emb_to_prediction( drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Convert WD14 embedding to understandable prediction result. :param emb: The 1-dim extracted embedding. :type emb: numpy.ndarray :param model_name: The name of the model to use. :type model_name: str :param general_threshold: The threshold for general tags. :type general_threshold: float :param general_mcut_enabled: If True, applies MCut thresholding to general tags. :type general_mcut_enabled: bool :param character_threshold: The threshold for character tags. :type character_threshold: float :param character_mcut_enabled: If True, applies MCut thresholding to character tags. :type character_mcut_enabled: bool :param no_underline: If True, replaces underscores in tag names with spaces. 
:type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. :return: Prediction result based on the provided fmt. .. note:: Only embeddings that have not been normalized can be converted to an understandable prediction result. Example: >>> import os >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction >>> >>> # extract the feature embedding >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding') >>> >>> # convert to understandable result >>> rating, general, character = convert_wd14_emb_to_prediction(embedding) >>> # these 3 dicts will be the same as those returned by `get_wd14_tags('skadi.jpg')` """ z_weights = _get_wd14_weights(model_name) weights, bias = z_weights['weights'], z_weights['bias'] pred = sigmoid(emb @ weights + bias) Loading Loading
docs/source/api_doc/tagging/wd14.rst +8 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,11 @@ get_wd14_tags .. autofunction:: get_wd14_tags convert_wd14_emb_to_prediction ---------------------------------------------- .. autofunction:: convert_wd14_emb_to_prediction
imgutils/tagging/wd14.py +115 −19 Original line number Diff line number Diff line """ Overview: Tagging utils based on wd14 v2, inspired by `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . This module provides utilities for image tagging using WD14 taggers. It includes functions for loading models, processing images, and extracting tags. The module is inspired by the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. """ from functools import lru_cache from typing import List, Tuple Loading Loading @@ -51,6 +56,13 @@ _DEFAULT_MODEL_NAME = 'SwinV2_v3' def _version_support_check(model_name): """ Check if the current onnxruntime version supports the given model. :param model_name: The name of the model to check. :type model_name: str :raises EnvironmentError: If the model is not supported by the current onnxruntime version. """ if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.\n' Loading @@ -63,7 +75,7 @@ def _get_wd14_model(model_name): """ Load an ONNX model from the Hugging Face Hub. :param model_name: The name of the model. :param model_name: The name of the model to load. :type model_name: str :return: The loaded ONNX model. :rtype: ONNXModel Loading @@ -77,6 +89,14 @@ def _get_wd14_model(model_name): @lru_cache() def _get_wd14_weights(model_name): """ Load the weights for a WD14 model. :param model_name: The name of the model. :type model_name: str :return: The loaded weights. 
:rtype: numpy.ndarray """ _version_support_check(model_name) return np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', Loading Loading @@ -111,10 +131,17 @@ def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Compute the Maximum Cut Thresholding (MCut) for multi-label classification. This method is based on the paper: Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). :param probs: Array of probabilities. :type probs: numpy.ndarray :return: The computed threshold. :rtype: float """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] Loading @@ -124,6 +151,16 @@ def _mcut_threshold(probs) -> float: def _prepare_image_for_tagging(image: ImageTyping, target_size: int): """ Prepare an image for tagging by resizing and padding it. :param image: The input image. :type image: ImageTyping :param target_size: The target size for the image. :type target_size: int :return: The prepared image as a numpy array. :rtype: numpy.ndarray """ image = load_image(image, force_background=None, mode=None) image_shape = image.size max_dim = max(image_shape) Loading Loading @@ -155,6 +192,30 @@ def _postprocess_embedding( drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Post-process the embedding and prediction results. :param pred: The prediction array. :type pred: numpy.ndarray :param embedding: The embedding array. :type embedding: numpy.ndarray :param model_name: The name of the model used. :type model_name: str :param general_threshold: Threshold for general tags. :type general_threshold: float :param general_mcut_enabled: Whether to use MCut for general tags. :type general_mcut_enabled: bool :param character_threshold: Threshold for character tags. 
:type character_threshold: float :param character_mcut_enabled: Whether to use MCut for character tags. :type character_mcut_enabled: bool :param no_underline: Whether to remove underscores from tag names. :type no_underline: bool :param drop_overlap: Whether to drop overlapping tags. :type drop_overlap: bool :param fmt: The format of the output. :return: The post-processed results. """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) labels = list(zip(tag_names, pred.astype(float))) Loading Loading @@ -202,9 +263,10 @@ def get_wd14_tags( fmt=('rating', 'general', 'character'), ): """ Overview: Get tags for an image with wd14 taggers. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . Get tags for an image using WD14 taggers. This function is similar to the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. :param image: The input image. :type image: ImageTyping Loading @@ -224,19 +286,17 @@ def get_wd14_tags( :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. :type fmt: Any :return: Prediction result based on the provided ``fmt``. In the default case, it should be a tuple of ``rating``, ``general`` and ``character``. :return: Prediction result based on the provided fmt. .. 
note:: About ``fmt`` argument, these are the available names: The fmt argument can include the following keys: * ``rating``, a dict containing ratings and their confidences * ``general``, a dict containing general tags and their confidences * ``character``, a dict containing character tags and their confidences * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization * ``prediction``, a 1-dim prediction result of image - ``rating``: a dict containing ratings and their confidences - ``general``: a dict containing general tags and their confidences - ``character``: a dict containing character tags and their confidences - ``tag``: a dict containing all tags (including general and character, not including rating) and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``prediction``: a 1-dim prediction result of image Example: Here are some images for example Loading Loading @@ -299,6 +359,42 @@ def convert_wd14_emb_to_prediction( drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Convert WD14 embedding to understandable prediction result. :param emb: The 1-dim extracted embedding. :type emb: numpy.ndarray :param model_name: The name of the model to use. :type model_name: str :param general_threshold: The threshold for general tags. :type general_threshold: float :param general_mcut_enabled: If True, applies MCut thresholding to general tags. :type general_mcut_enabled: bool :param character_threshold: The threshold for character tags. :type character_threshold: float :param character_mcut_enabled: If True, applies MCut thresholding to character tags. :type character_mcut_enabled: bool :param no_underline: If True, replaces underscores in tag names with spaces. 
:type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. :return: Prediction result based on the provided fmt. .. note:: Only embeddings that have not been normalized can be converted to an understandable prediction result. Example: >>> import os >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction >>> >>> # extract the feature embedding >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding') >>> >>> # convert to understandable result >>> rating, general, character = convert_wd14_emb_to_prediction(embedding) >>> # these 3 dicts will be the same as those returned by `get_wd14_tags('skadi.jpg')` """ z_weights = _get_wd14_weights(model_name) weights, bias = z_weights['weights'], z_weights['bias'] pred = sigmoid(emb @ weights + bias) Loading