Loading docs/source/api_doc/tagging/wd14.rst +8 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,11 @@ get_wd14_tags .. autofunction:: get_wd14_tags convert_wd14_emb_to_prediction ---------------------------------------------- .. autofunction:: convert_wd14_emb_to_prediction docs/source/api_doc/utils/func.rst 0 → 100644 +14 −0 Original line number Diff line number Diff line imgutils.utils.func ==================================== .. currentmodule:: imgutils.utils.func .. automodule:: imgutils.utils.func sigmoid ------------------------- .. autofunction:: sigmoid docs/source/api_doc/utils/index.rst +1 −0 Original line number Diff line number Diff line Loading @@ -10,4 +10,5 @@ imgutils.utils :maxdepth: 3 cache func onnxruntime imgutils/tagging/__init__.py +1 −1 Original line number Diff line number Diff line Loading @@ -16,4 +16,4 @@ from .match import tag_match_suffix, tag_match_prefix, tag_match_full from .mldanbooru import get_mldanbooru_tags from .order import sort_tags from .overlap import drop_overlap_tags from .wd14 import get_wd14_tags from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction imgutils/tagging/wd14.py +217 −48 Original line number Diff line number Diff line """ Overview: Tagging utils based on wd14 v2, inspired by `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . This module provides utilities for image tagging using WD14 taggers. It includes functions for loading models, processing images, and extracting tags. The module is inspired by the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. 
""" from typing import List, Tuple, Dict from typing import List, Tuple import numpy as np import onnxruntime Loading @@ -15,7 +19,7 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model, vreplace, ts_lru_cache from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -50,6 +54,13 @@ _DEFAULT_MODEL_NAME = 'SwinV2_v3' def _version_support_check(model_name): """ Check if the current onnxruntime version supports the given model. :param model_name: The name of the model to check. :type model_name: str :raises EnvironmentError: If the model is not supported by the current onnxruntime version. """ if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.\n' Loading @@ -62,7 +73,7 @@ def _get_wd14_model(model_name): """ Load an ONNX model from the Hugging Face Hub. :param model_name: The name of the model. :param model_name: The name of the model to load. :type model_name: str :return: The loaded ONNX model. :rtype: ONNXModel Loading @@ -74,6 +85,23 @@ def _get_wd14_model(model_name): )) @ts_lru_cache() def _get_wd14_weights(model_name): """ Load the weights for a WD14 model. :param model_name: The name of the model. :type model_name: str :return: The loaded weights. 
:rtype: numpy.ndarray """ _version_support_check(model_name) return np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', filename=f'{MODEL_NAMES[model_name]}/inv.npz', )) @ts_lru_cache() def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: """ Loading Loading @@ -101,10 +129,17 @@ def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Compute the Maximum Cut Thresholding (MCut) for multi-label classification. This method is based on the paper: Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). :param probs: Array of probabilities. :type probs: numpy.ndarray :return: The computed threshold. :rtype: float """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] Loading @@ -114,6 +149,16 @@ def _mcut_threshold(probs) -> float: def _prepare_image_for_tagging(image: ImageTyping, target_size: int): """ Prepare an image for tagging by resizing and padding it. :param image: The input image. :type image: ImageTyping :param target_size: The target size for the image. :type target_size: int :return: The prepared image as a numpy array. 
:rtype: numpy.ndarray """ image = load_image(image, force_background=None, mode=None) image_shape = image.size max_dim = max(image_shape) Loading @@ -134,6 +179,76 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int): return np.expand_dims(image_array, axis=0) def _postprocess_embedding( pred, embedding, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Post-process the embedding and prediction results. :param pred: The prediction array. :type pred: numpy.ndarray :param embedding: The embedding array. :type embedding: numpy.ndarray :param model_name: The name of the model used. :type model_name: str :param general_threshold: Threshold for general tags. :type general_threshold: float :param general_mcut_enabled: Whether to use MCut for general tags. :type general_mcut_enabled: bool :param character_threshold: Threshold for character tags. :type character_threshold: float :param character_mcut_enabled: Whether to use MCut for character tags. :type character_mcut_enabled: bool :param no_underline: Whether to remove underscores from tag names. :type no_underline: bool :param drop_overlap: Whether to drop overlapping tags. :type drop_overlap: bool :param fmt: The format of the output. :return: The post-processed results. 
""" tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) labels = list(zip(tag_names, pred.astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embedding.astype(np.float32), 'prediction': pred.astype(np.float32), } ) def get_wd14_tags( image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, Loading @@ -146,9 +261,10 @@ def get_wd14_tags( fmt=('rating', 'general', 'character'), ): """ Overview: Get tags for an image with wd14 taggers. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . Get tags for an image using WD14 taggers. This function is similar to the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. :param image: The input image. :type image: ImageTyping Loading @@ -168,19 +284,28 @@ def get_wd14_tags( :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. 
:type fmt: Any :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] :return: Prediction result based on the provided fmt. .. note:: About ``fmt`` argument, these are the available names: The fmt argument can include the following keys: - ``rating``: a dict containing ratings and their confidences - ``general``: a dict containing general tags and their confidences - ``character``: a dict containing character tags and their confidences - ``tag``: a dict containing all tags (including general and character, not including rating) and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``prediction``: a 1-dim prediction result of image * ``rating``, a dict containing ratings and their confidences * ``general``, a dict containing general tags and their confidences * ``character``, a dict containing character tags and their confidences * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization * ``prediction``, a 1-dim prediction result of image You can extract embedding of the given image with the following code >>> from imgutils.tagging import get_wd14_tags >>> >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding') >>> embedding.shape (1024,) This embedding is valuable for constructing indices that enable rapid querying of images based on visual features within large-scale datasets. Example: Here are some images for example Loading @@ -188,7 +313,6 @@ def get_wd14_tags( .. 
image:: tagging_demo.plot.py.svg :align: center >>> import os >>> from imgutils.tagging import get_wd14_tags >>> >>> rating, features, chars = get_wd14_tags('skadi.jpg') Loading @@ -207,7 +331,7 @@ def get_wd14_tags( >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) Loading @@ -217,35 +341,80 @@ def get_wd14_tags( label_name = model.get_outputs()[0].name emb_name = model.get_outputs()[1].name preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} return _postprocess_embedding( pred=preds[0], embedding=embeddings[0], model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, ) general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) def convert_wd14_emb_to_prediction( emb: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Convert WD14 embedding to understandable prediction result. 
character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) :param emb: The 1-dim extracted embedding. :type emb: numpy.ndarray :param model_name: The name of the model to use. :type model_name: str :param general_threshold: The threshold for general tags. :type general_threshold: float :param general_mcut_enabled: If True, applies MCut thresholding to general tags. :type general_mcut_enabled: bool :param character_threshold: The threshold for character tags. :type character_threshold: float :param character_mcut_enabled: If True, applies MCut thresholding to character tags. :type character_mcut_enabled: bool :param no_underline: If True, replaces underscores in tag names with spaces. :type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. :return: Prediction result based on the provided fmt. character_res = {x: v.item() for x, v in character_names if v > character_threshold} .. note:: Only embeddings that have not been normalized can be converted to an understandable prediction result. 
return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings[0].astype(np.float32), 'prediction': preds[0].astype(np.float32), } Example: >>> import os >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction >>> >>> # extract the feature embedding >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding') >>> >>> # convert to understandable result >>> rating, general, character = convert_wd14_emb_to_prediction(embedding) >>> # these 3 dicts will be the same as that returned by `get_wd14_tags('skadi.jpg')` """ z_weights = _get_wd14_weights(model_name) weights, bias = z_weights['weights'], z_weights['bias'] pred = sigmoid(emb @ weights + bias) return _postprocess_embedding( pred=pred, embedding=emb, model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, ) Loading
docs/source/api_doc/tagging/wd14.rst +8 −0 Original line number Diff line number Diff line Loading @@ -12,3 +12,11 @@ get_wd14_tags .. autofunction:: get_wd14_tags convert_wd14_emb_to_prediction ---------------------------------------------- .. autofunction:: convert_wd14_emb_to_prediction
docs/source/api_doc/utils/func.rst 0 → 100644 +14 −0 Original line number Diff line number Diff line imgutils.utils.func ==================================== .. currentmodule:: imgutils.utils.func .. automodule:: imgutils.utils.func sigmoid ------------------------- .. autofunction:: sigmoid
docs/source/api_doc/utils/index.rst +1 −0 Original line number Diff line number Diff line Loading @@ -10,4 +10,5 @@ imgutils.utils :maxdepth: 3 cache func onnxruntime
imgutils/tagging/__init__.py +1 −1 Original line number Diff line number Diff line Loading @@ -16,4 +16,4 @@ from .match import tag_match_suffix, tag_match_prefix, tag_match_full from .mldanbooru import get_mldanbooru_tags from .order import sort_tags from .overlap import drop_overlap_tags from .wd14 import get_wd14_tags from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction
imgutils/tagging/wd14.py +217 −48 Original line number Diff line number Diff line """ Overview: Tagging utils based on wd14 v2, inspired by `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . This module provides utilities for image tagging using WD14 taggers. It includes functions for loading models, processing images, and extracting tags. The module is inspired by the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. """ from typing import List, Tuple, Dict from typing import List, Tuple import numpy as np import onnxruntime Loading @@ -15,7 +19,7 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model, vreplace, ts_lru_cache from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -50,6 +54,13 @@ _DEFAULT_MODEL_NAME = 'SwinV2_v3' def _version_support_check(model_name): """ Check if the current onnxruntime version supports the given model. :param model_name: The name of the model to check. :type model_name: str :raises EnvironmentError: If the model is not supported by the current onnxruntime version. """ if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.\n' Loading @@ -62,7 +73,7 @@ def _get_wd14_model(model_name): """ Load an ONNX model from the Hugging Face Hub. :param model_name: The name of the model. :param model_name: The name of the model to load. :type model_name: str :return: The loaded ONNX model. 
:rtype: ONNXModel Loading @@ -74,6 +85,23 @@ def _get_wd14_model(model_name): )) @ts_lru_cache() def _get_wd14_weights(model_name): """ Load the weights for a WD14 model. :param model_name: The name of the model. :type model_name: str :return: The loaded weights. :rtype: numpy.ndarray """ _version_support_check(model_name) return np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', filename=f'{MODEL_NAMES[model_name]}/inv.npz', )) @ts_lru_cache() def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: """ Loading Loading @@ -101,10 +129,17 @@ def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Compute the Maximum Cut Thresholding (MCut) for multi-label classification. This method is based on the paper: Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). :param probs: Array of probabilities. :type probs: numpy.ndarray :return: The computed threshold. :rtype: float """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] Loading @@ -114,6 +149,16 @@ def _mcut_threshold(probs) -> float: def _prepare_image_for_tagging(image: ImageTyping, target_size: int): """ Prepare an image for tagging by resizing and padding it. :param image: The input image. :type image: ImageTyping :param target_size: The target size for the image. :type target_size: int :return: The prepared image as a numpy array. 
:rtype: numpy.ndarray """ image = load_image(image, force_background=None, mode=None) image_shape = image.size max_dim = max(image_shape) Loading @@ -134,6 +179,76 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int): return np.expand_dims(image_array, axis=0) def _postprocess_embedding( pred, embedding, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Post-process the embedding and prediction results. :param pred: The prediction array. :type pred: numpy.ndarray :param embedding: The embedding array. :type embedding: numpy.ndarray :param model_name: The name of the model used. :type model_name: str :param general_threshold: Threshold for general tags. :type general_threshold: float :param general_mcut_enabled: Whether to use MCut for general tags. :type general_mcut_enabled: bool :param character_threshold: Threshold for character tags. :type character_threshold: float :param character_mcut_enabled: Whether to use MCut for character tags. :type character_mcut_enabled: bool :param no_underline: Whether to remove underscores from tag names. :type no_underline: bool :param drop_overlap: Whether to drop overlapping tags. :type drop_overlap: bool :param fmt: The format of the output. :return: The post-processed results. 
""" tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) labels = list(zip(tag_names, pred.astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embedding.astype(np.float32), 'prediction': pred.astype(np.float32), } ) def get_wd14_tags( image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, Loading @@ -146,9 +261,10 @@ def get_wd14_tags( fmt=('rating', 'general', 'character'), ): """ Overview: Get tags for an image with wd14 taggers. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . Get tags for an image using WD14 taggers. This function is similar to the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face. :param image: The input image. :type image: ImageTyping Loading @@ -168,19 +284,28 @@ def get_wd14_tags( :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. 
:type fmt: Any :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] :return: Prediction result based on the provided fmt. .. note:: About ``fmt`` argument, these are the available names: The fmt argument can include the following keys: - ``rating``: a dict containing ratings and their confidences - ``general``: a dict containing general tags and their confidences - ``character``: a dict containing character tags and their confidences - ``tag``: a dict containing all tags (including general and character, not including rating) and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``prediction``: a 1-dim prediction result of image * ``rating``, a dict containing ratings and their confidences * ``general``, a dict containing general tags and their confidences * ``character``, a dict containing character tags and their confidences * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization * ``prediction``, a 1-dim prediction result of image You can extract embedding of the given image with the following code >>> from imgutils.tagging import get_wd14_tags >>> >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding') >>> embedding.shape (1024,) This embedding is valuable for constructing indices that enable rapid querying of images based on visual features within large-scale datasets. Example: Here are some images for example Loading @@ -188,7 +313,6 @@ def get_wd14_tags( .. 
image:: tagging_demo.plot.py.svg :align: center >>> import os >>> from imgutils.tagging import get_wd14_tags >>> >>> rating, features, chars = get_wd14_tags('skadi.jpg') Loading @@ -207,7 +331,7 @@ def get_wd14_tags( >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) Loading @@ -217,35 +341,80 @@ def get_wd14_tags( label_name = model.get_outputs()[0].name emb_name = model.get_outputs()[1].name preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} return _postprocess_embedding( pred=preds[0], embedding=embeddings[0], model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, ) general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) def convert_wd14_emb_to_prediction( emb: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Convert WD14 embedding to understandable prediction result. 
character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) :param emb: The 1-dim extracted embedding. :type emb: numpy.ndarray :param model_name: The name of the model to use. :type model_name: str :param general_threshold: The threshold for general tags. :type general_threshold: float :param general_mcut_enabled: If True, applies MCut thresholding to general tags. :type general_mcut_enabled: bool :param character_threshold: The threshold for character tags. :type character_threshold: float :param character_mcut_enabled: If True, applies MCut thresholding to character tags. :type character_mcut_enabled: bool :param no_underline: If True, replaces underscores in tag names with spaces. :type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. :return: Prediction result based on the provided fmt. character_res = {x: v.item() for x, v in character_names if v > character_threshold} .. note:: Only embeddings that have not been normalized can be converted to an understandable prediction result. 
return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings[0].astype(np.float32), 'prediction': preds[0].astype(np.float32), } Example: >>> import os >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction >>> >>> # extract the feature embedding >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding') >>> >>> # convert to understandable result >>> rating, general, character = convert_wd14_emb_to_prediction(embedding) >>> # these 3 dicts will be the same as that returned by `get_wd14_tags('skadi.jpg')` """ z_weights = _get_wd14_weights(model_name) weights, bias = z_weights['weights'], z_weights['bias'] pred = sigmoid(emb @ weights + bias) return _postprocess_embedding( pred=pred, embedding=emb, model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, )