Commit fac816fb authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add convert from emb to pred

parent 06278397
Loading
Loading
Loading
Loading
+98 −34
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ Overview:
    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
"""
from functools import lru_cache
from typing import List, Tuple, Dict
from typing import List, Tuple

import numpy as np
import onnxruntime
@@ -15,8 +15,8 @@ from huggingface_hub import hf_hub_download

from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping, has_alpha_channel
from ..utils import open_onnx_model, vreplace
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model, vreplace, sigmoid

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
@@ -75,6 +75,15 @@ def _get_wd14_model(model_name):
    ))


@lru_cache()
def _get_wd14_weights(model_name):
    _version_support_check(model_name)
    return np.load(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        filename=f'{MODEL_NAMES[model_name]}/inv.npz',
    ))


@lru_cache()
def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]:
    """
@@ -135,6 +144,52 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    return np.expand_dims(image_array, axis=0)


def _postprocess_embedding(
        pred, embedding,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.35,
        general_mcut_enabled: bool = False,
        character_threshold: float = 0.85,
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt=('rating', 'general', 'character'),
):
    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
    labels = list(zip(tag_names, pred.astype(float)))

    rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}

    general_names = [labels[i] for i in general_indexes]
    if general_mcut_enabled:
        general_probs = np.array([x[1] for x in general_names])
        general_threshold = _mcut_threshold(general_probs)

    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)

    character_names = [labels[i] for i in character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_names])
        character_threshold = _mcut_threshold(character_probs)
        character_threshold = max(0.15, character_threshold)

    character_res = {x: v.item() for x, v in character_names if v > character_threshold}

    return vreplace(
        fmt,
        {
            'rating': rating,
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embedding.astype(np.float32),
            'prediction': pred.astype(np.float32),
        }
    )


def get_wd14_tags(
        image: ImageTyping,
        model_name: str = _DEFAULT_MODEL_NAME,
@@ -170,8 +225,8 @@ def get_wd14_tags(
    :param fmt: Return format, default is ``('rating', 'general', 'character')``.
        ``embedding`` is also supported for feature extraction.
    :type fmt: Any
    :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
    :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]
    :return: Prediction result based on the provided ``fmt``.
             In the default case, it should be a tuple of ``rating``, ``general`` and ``character``.

    .. note::
        About ``fmt`` argument, these are the available names:
@@ -208,7 +263,7 @@ def get_wd14_tags(
        >>> chars
        {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541}
    """
    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)

    model = _get_wd14_model(model_name)
    _, target_size, _, _ = model.get_inputs()[0].shape
    image = _prepare_image_for_tagging(image, target_size)
@@ -218,35 +273,44 @@ def get_wd14_tags(
    label_name = model.get_outputs()[0].name
    emb_name = model.get_outputs()[1].name
    preds, embeddings = model.run([label_name, emb_name], {input_name: image})
    labels = list(zip(tag_names, preds[0].astype(float)))

    rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}

    general_names = [labels[i] for i in general_indexes]
    if general_mcut_enabled:
        general_probs = np.array([x[1] for x in general_names])
        general_threshold = _mcut_threshold(general_probs)

    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)

    character_names = [labels[i] for i in character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_names])
        character_threshold = _mcut_threshold(character_probs)
        character_threshold = max(0.15, character_threshold)
    return _postprocess_embedding(
        pred=preds[0],
        embedding=embeddings[0],
        model_name=model_name,
        general_threshold=general_threshold,
        general_mcut_enabled=general_mcut_enabled,
        character_threshold=character_threshold,
        character_mcut_enabled=character_mcut_enabled,
        no_underline=no_underline,
        drop_overlap=drop_overlap,
        fmt=fmt,
    )

    character_res = {x: v.item() for x, v in character_names if v > character_threshold}

    return vreplace(
        fmt,
        {
            'rating': rating,
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embeddings[0].astype(np.float32),
            'prediction': preds[0].astype(np.float32),
        }
def convert_wd14_emb_to_prediction(
        emb: np.ndarray,
        model_name: str = _DEFAULT_MODEL_NAME,
        general_threshold: float = 0.35,
        general_mcut_enabled: bool = False,
        character_threshold: float = 0.85,
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt=('rating', 'general', 'character'),
):
    z_weights = _get_wd14_weights(model_name)
    weights, bias = z_weights['weights'], z_weights['bias']
    pred = sigmoid(emb @ weights + bias)
    return _postprocess_embedding(
        pred=pred,
        embedding=emb,
        model_name=model_name,
        general_threshold=general_threshold,
        general_mcut_enabled=general_mcut_enabled,
        character_threshold=character_threshold,
        character_mcut_enabled=character_mcut_enabled,
        no_underline=no_underline,
        drop_overlap=drop_overlap,
        fmt=fmt,
    )
+1 −0
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ Overview:
"""
from .area import *
from .format import *
from .func import *
from .onnxruntime import *
from .storage import *
from .tqdm_ import *

imgutils/utils/func.py

0 → 100644
+7 −0
Original line number Diff line number Diff line
import numpy as np

__all__ = ['sigmoid']


def sigmoid(x):
    return 1 / (1 + np.exp(-x))