Commit 6a8c3cf6 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): merge from main

parents 2feb78cf c1bca124
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -12,3 +12,11 @@ get_wd14_tags
.. autofunction:: get_wd14_tags



convert_wd14_emb_to_prediction
----------------------------------------------

.. autofunction:: convert_wd14_emb_to_prediction


+14 −0
Original line number Diff line number Diff line
imgutils.utils.func
====================================

.. currentmodule:: imgutils.utils.func

.. automodule:: imgutils.utils.func


sigmoid
-------------------------

.. autofunction:: sigmoid

+1 −0
Original line number Diff line number Diff line
@@ -10,4 +10,5 @@ imgutils.utils
    :maxdepth: 3

    cache
    func
    onnxruntime
+1 −1
Original line number Diff line number Diff line
@@ -16,4 +16,4 @@ from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags
from .wd14 import get_wd14_tags
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction
+217 −48
Original line number Diff line number Diff line
"""
Overview:
    Tagging utils based on wd14 v2, inspired by
    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
    This module provides utilities for image tagging using WD14 taggers.
    It includes functions for loading models, processing images, and extracting tags.

    The module is inspired by the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_
    project on Hugging Face.

"""
from typing import List, Tuple, Dict
from typing import List, Tuple

import numpy as np
import onnxruntime
@@ -15,7 +19,7 @@ from huggingface_hub import hf_hub_download
from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model, vreplace, ts_lru_cache
from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
@@ -50,6 +54,13 @@ _DEFAULT_MODEL_NAME = 'SwinV2_v3'


def _version_support_check(model_name):
    """
    Check if the current onnxruntime version supports the given model.

    :param model_name: The name of the model to check.
    :type model_name: str
    :raises EnvironmentError: If the model is not supported by the current onnxruntime version.
    """
    if model_name.endswith('_v3') and not _IS_V3_SUPPORT:
        raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, '
                               f'please upgrade it to 1.17+ version.\n'
@@ -62,7 +73,7 @@ def _get_wd14_model(model_name):
    """
    Load an ONNX model from the Hugging Face Hub.

    :param model_name: The name of the model.
    :param model_name: The name of the model to load.
    :type model_name: str
    :return: The loaded ONNX model.
    :rtype: ONNXModel
@@ -74,6 +85,23 @@ def _get_wd14_model(model_name):
    ))


@ts_lru_cache()
def _get_wd14_weights(model_name):
    """
    Load the weights for a WD14 model.

    :param model_name: The name of the model.
    :type model_name: str
    :return: The loaded weights.
    :rtype: numpy.ndarray
    """
    _version_support_check(model_name)
    return np.load(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        filename=f'{MODEL_NAMES[model_name]}/inv.npz',
    ))


@ts_lru_cache()
def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]:
    """
@@ -101,10 +129,17 @@ def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str],

def _mcut_threshold(probs) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Compute the Maximum Cut Thresholding (MCut) for multi-label classification.

    This method is based on the paper:
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    :param probs: Array of probabilities.
    :type probs: numpy.ndarray
    :return: The computed threshold.
    :rtype: float
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
@@ -114,6 +149,16 @@ def _mcut_threshold(probs) -> float:


def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    """
    Prepare an image for tagging by resizing and padding it.

    :param image: The input image.
    :type image: ImageTyping
    :param target_size: The target size for the image.
    :type target_size: int
    :return: The prepared image as a numpy array.
    :rtype: numpy.ndarray
    """
    image = load_image(image, force_background=None, mode=None)
    image_shape = image.size
    max_dim = max(image_shape)
@@ -134,6 +179,76 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    return np.expand_dims(image_array, axis=0)


def _postprocess_embedding(
        pred, embedding,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.35,
        general_mcut_enabled: bool = False,
        character_threshold: float = 0.85,
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt=('rating', 'general', 'character'),
):
    """
    Post-process the embedding and prediction results.

    :param pred: The prediction array.
    :type pred: numpy.ndarray
    :param embedding: The embedding array.
    :type embedding: numpy.ndarray
    :param model_name: The name of the model used.
    :type model_name: str
    :param general_threshold: Threshold for general tags.
    :type general_threshold: float
    :param general_mcut_enabled: Whether to use MCut for general tags.
    :type general_mcut_enabled: bool
    :param character_threshold: Threshold for character tags.
    :type character_threshold: float
    :param character_mcut_enabled: Whether to use MCut for character tags.
    :type character_mcut_enabled: bool
    :param no_underline: Whether to remove underscores from tag names.
    :type no_underline: bool
    :param drop_overlap: Whether to drop overlapping tags.
    :type drop_overlap: bool
    :param fmt: The format of the output.
    :return: The post-processed results.
    """
    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
    labels = list(zip(tag_names, pred.astype(float)))

    rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}

    general_names = [labels[i] for i in general_indexes]
    if general_mcut_enabled:
        general_probs = np.array([x[1] for x in general_names])
        general_threshold = _mcut_threshold(general_probs)

    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)

    character_names = [labels[i] for i in character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_names])
        character_threshold = _mcut_threshold(character_probs)
        character_threshold = max(0.15, character_threshold)

    character_res = {x: v.item() for x, v in character_names if v > character_threshold}

    return vreplace(
        fmt,
        {
            'rating': rating,
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embedding.astype(np.float32),
            'prediction': pred.astype(np.float32),
        }
    )


def get_wd14_tags(
        image: ImageTyping,
        model_name: str = _DEFAULT_MODEL_NAME,
@@ -146,9 +261,10 @@ def get_wd14_tags(
        fmt=('rating', 'general', 'character'),
):
    """
    Overview:
        Get tags for an image with wd14 taggers.
        Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
    Get tags for an image using WD14 taggers.

    This function is similar to the
    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ project on Hugging Face.

    :param image: The input image.
    :type image: ImageTyping
@@ -168,19 +284,28 @@ def get_wd14_tags(
    :type drop_overlap: bool
    :param fmt: Return format, default is ``('rating', 'general', 'character')``.
        ``embedding`` is also supported for feature extraction.
    :type fmt: Any
    :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
    :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]
    :return: Prediction result based on the provided fmt.

    .. note::
        About ``fmt`` argument, these are the available names:
        The fmt argument can include the following keys:

        - ``rating``: a dict containing ratings and their confidences
        - ``general``: a dict containing general tags and their confidences
        - ``character``: a dict containing character tags and their confidences
        - ``tag``: a dict containing all tags (including general and character, not including rating) and their confidences
        - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization
        - ``prediction``: a 1-dim prediction result of image

        * ``rating``, a dict containing ratings and their confidences
        * ``general``, a dict containing general tags and their confidences
        * ``character``, a dict containing character tags and their confidences
        * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences
        * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization
        * ``prediction``, a 1-dim prediction result of image
        You can extract embedding of the given image with the follwing code

        >>> from imgutils.tagging import get_wd14_tags
        >>>
        >>> embedding = get_wd14_tags('skadi.jpg', fmt='embdding')
        >>> embedding.shape
        (1024, )

        This embedding is valuable for constructing indices that enable rapid querying of images based on
        visual features within large-scale datasets.

    Example:
        Here are some images for example
@@ -188,7 +313,6 @@ def get_wd14_tags(
        .. image:: tagging_demo.plot.py.svg
           :align: center

        >>> import os
        >>> from imgutils.tagging import get_wd14_tags
        >>>
        >>> rating, features, chars = get_wd14_tags('skadi.jpg')
@@ -207,7 +331,7 @@ def get_wd14_tags(
        >>> chars
        {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541}
    """
    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)

    model = _get_wd14_model(model_name)
    _, target_size, _, _ = model.get_inputs()[0].shape
    image = _prepare_image_for_tagging(image, target_size)
@@ -217,35 +341,80 @@ def get_wd14_tags(
    label_name = model.get_outputs()[0].name
    emb_name = model.get_outputs()[1].name
    preds, embeddings = model.run([label_name, emb_name], {input_name: image})
    labels = list(zip(tag_names, preds[0].astype(float)))

    rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}
    return _postprocess_embedding(
        pred=preds[0],
        embedding=embeddings[0],
        model_name=model_name,
        general_threshold=general_threshold,
        general_mcut_enabled=general_mcut_enabled,
        character_threshold=character_threshold,
        character_mcut_enabled=character_mcut_enabled,
        no_underline=no_underline,
        drop_overlap=drop_overlap,
        fmt=fmt,
    )

    general_names = [labels[i] for i in general_indexes]
    if general_mcut_enabled:
        general_probs = np.array([x[1] for x in general_names])
        general_threshold = _mcut_threshold(general_probs)

    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)
def convert_wd14_emb_to_prediction(
        emb: np.ndarray,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.35,
        general_mcut_enabled: bool = False,
        character_threshold: float = 0.85,
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt=('rating', 'general', 'character'),
):
    """
    Convert WD14 embedding to understandable prediction result.

    character_names = [labels[i] for i in character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_names])
        character_threshold = _mcut_threshold(character_probs)
        character_threshold = max(0.15, character_threshold)
    :param emb: The 1-dim extracted embedding.
    :type emb: numpy.ndarray
    :param model_name: The name of the model to use.
    :type model_name: str
    :param general_threshold: The threshold for general tags.
    :type general_threshold: float
    :param general_mcut_enabled: If True, applies MCut thresholding to general tags.
    :type general_mcut_enabled: bool
    :param character_threshold: The threshold for character tags.
    :type character_threshold: float
    :param character_mcut_enabled: If True, applies MCut thresholding to character tags.
    :type character_mcut_enabled: bool
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :type no_underline: bool
    :param drop_overlap: If True, drops overlapping tags.
    :type drop_overlap: bool
    :param fmt: Return format, default is ``('rating', 'general', 'character')``.
    :return: Prediction result based on the provided fmt.

    character_res = {x: v.item() for x, v in character_names if v > character_threshold}
    .. note::
        Only the embeddings not get normalized can be converted to understandable prediction result.

    return vreplace(
        fmt,
        {
            'rating': rating,
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embeddings[0].astype(np.float32),
            'prediction': preds[0].astype(np.float32),
        }
    Example:
        >>> import os
        >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
        >>>
        >>> # extract the feature embedding
        >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding')
        >>>
        >>> # convert to understandable result
        >>> rating, general, character = convert_wd14_emb_to_prediction(embedding)
        >>> # these 3 dicts will be the same as that returned by `get_wd14_tags('skadi.jpg')`
    """
    z_weights = _get_wd14_weights(model_name)
    weights, bias = z_weights['weights'], z_weights['bias']
    pred = sigmoid(emb @ weights + bias)
    return _postprocess_embedding(
        pred=pred,
        embedding=emb,
        model_name=model_name,
        general_threshold=general_threshold,
        general_mcut_enabled=general_mcut_enabled,
        character_threshold=character_threshold,
        character_mcut_enabled=character_mcut_enabled,
        no_underline=no_underline,
        drop_overlap=drop_overlap,
        fmt=fmt,
    )
Loading