Unverified Commit eda6d0ba authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #72 from deepghs/dev/wd14v3

dev(narugo): add support for wd14 v3 models
parents f66b044c 4a04dcaf
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -59,6 +59,9 @@ if __name__ == '__main__':
            ('wd14-convnextv2', Wd14Benchmark("ConvNextV2")),
            ('wd14-vit', Wd14Benchmark("ViT")),
            ('wd14-moat', Wd14Benchmark("MOAT")),
            ('wd14-swinv2-v3', Wd14Benchmark("SwinV2_v3")),
            ('wd14-vit-v3', Wd14Benchmark("ViT_v3")),
            ('wd14-convnext-v3', Wd14Benchmark("ConvNext_v3")),
            ('mldanbooru', MLDanbooruBenchmark()),
        ],
        title='Benchmark for Tagging Models',
+884 −589

File changed.

Preview size limit exceeded, changes collapsed.

+144 −67
Original line number Diff line number Diff line
@@ -6,98 +6,178 @@ Overview:
from functools import lru_cache
from typing import List, Tuple

import cv2
import huggingface_hub
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from hbutils.testing.requires.version import VersionInfo

from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model


def make_square(img, target_size):
    """
    Pad an image array with a white border until it is square.

    The side of the result is the largest of the image's first two dimensions
    and ``target_size``. The padding of each axis is split between its two
    edges; when it is odd, the extra pixel goes to the bottom / right edge.

    :param img: Image array; ``img.shape[0]`` and ``img.shape[1]`` are read as rows and columns.
    :param target_size: Minimum side length of the padded result.
    :return: The padded array produced by ``cv2.copyMakeBorder``.
    """
    dims = img.shape[:2]
    side = max(max(dims), target_size)

    extra_cols = side - dims[1]
    extra_rows = side - dims[0]
    pad_top = extra_rows // 2
    pad_bottom = extra_rows - pad_top
    pad_left = extra_cols // 2
    pad_right = extra_cols - pad_left

    # noinspection PyUnresolvedReferences
    return cv2.copyMakeBorder(
        img, pad_top, pad_bottom, pad_left, pad_right,
        cv2.BORDER_CONSTANT, value=[255, 255, 255],
    )


def smart_resize(img, size):
    """
    Resize a square image array to ``size`` x ``size``.

    Only ``img.shape[0]`` is inspected, so the image is expected to have gone
    through ``make_square`` already. Shrinking uses ``cv2.INTER_AREA``, enlarging
    uses ``cv2.INTER_CUBIC``, and an image that already has the requested size is
    returned unchanged.

    :param img: Square image array.
    :param size: Target side length.
    :return: The resized array, or ``img`` itself when no resize is needed.
    """
    current = img.shape[0]
    if current > size:
        # noinspection PyUnresolvedReferences
        method = cv2.INTER_AREA
    elif current < size:
        # noinspection PyUnresolvedReferences
        method = cv2.INTER_CUBIC
    else:
        return img

    # noinspection PyUnresolvedReferences
    return cv2.resize(img, (size, size), interpolation=method)


# Hugging Face Hub repository ids of the v2 tagger models.
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
# Hugging Face Hub repository ids of the v3 tagger models.
CONV_V3_MODEL_REPO = 'SmilingWolf/wd-convnext-tagger-v3'
SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3'
VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3'
# File names fetched from a model repository: the ONNX model and its tag table.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Evaluated once at import time: True when the installed onnxruntime is at least 1.17.
# Models whose name ends with '_v3' are refused when this is False.
_IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17'

# Maps the model name accepted by the tagging functions to its repository id.
MODEL_NAMES = {
    "SwinV2": SWIN_MODEL_REPO,
    "ConvNext": CONV_MODEL_REPO,
    "ConvNextV2": CONV2_MODEL_REPO,
    "ViT": VIT_MODEL_REPO,
    "MOAT": MOAT_MODEL_REPO,

    # v3 models; usable only when _IS_V3_SUPPORT is True.
    "SwinV2_v3": SWIN_V3_MODEL_REPO,
    "ConvNext_v3": CONV_V3_MODEL_REPO,
    "ViT_v3": VIT_V3_MODEL_REPO,
}


def _load_wd14_model(model_repo: str, model_filename: str):
    """
    Download a model file from the Hugging Face Hub and open it as an ONNX model.

    :param model_repo: Repository id passed to ``hf_hub_download``.
    :param model_filename: Name of the file inside that repository.
    :return: Whatever ``open_onnx_model`` returns for the downloaded file.
    """
    local_path = huggingface_hub.hf_hub_download(model_repo, model_filename)
    return open_onnx_model(local_path)
def _version_support_check(model_name):
    if model_name.endswith('_v3') and not _IS_V3_SUPPORT:
        raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, '
                               f'please upgrade it to 1.17+ version.')  # pragma: no cover


# Emoticon-style tag names in which the underscore is part of the face itself.
# When tag names are rewritten with no_underline=True, names listed here are
# kept verbatim instead of having their underscores replaced by spaces.
_KAOMOJIS = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


@lru_cache()
def _get_wd14_model(model_name):
    """
    Load the ONNX model of a WD14 tagger from the Hugging Face Hub.

    The result is memoized per ``model_name`` by ``lru_cache``, so each model is
    opened at most once while it stays in the cache.

    :param model_name: The name of the model, a key of ``MODEL_NAMES``.
    :type model_name: str
    :return: The loaded ONNX model, as returned by ``open_onnx_model``.
    :raises EnvironmentError: If a ``_v3`` model is requested on an onnxruntime
        version without v3 support.
    :raises KeyError: If ``model_name`` is not a key of ``MODEL_NAMES``.
    """
    # Run the support check first so an unsupported runtime fails before anything
    # is downloaded. A stale early ``return`` used to make this line unreachable.
    _version_support_check(model_name)
    model_path = huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)
    return open_onnx_model(model_path)


@lru_cache()
def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]:
    """
    Get labels for the WD14 model.

    The tag table (``LABEL_FILENAME``) is downloaded from the repository of the given
    model and parsed with pandas. Results are memoized per
    ``(model_name, no_underline)`` by ``lru_cache``; the returned lists are therefore
    shared between calls and must not be mutated by callers.

    :param model_name: The name of the model, a key of ``MODEL_NAMES``.
    :type model_name: str
    :param no_underline: If True, replaces underscores in tag names with spaces.
        Names listed in ``_KAOMOJIS`` are kept unchanged.
    :type no_underline: bool
    :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories.
    :rtype: Tuple[List[str], List[int], List[int], List[int]]
    :raises KeyError: If ``model_name`` is not a key of ``MODEL_NAMES``.
    """
    path = huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME)
    df = pd.read_csv(path)

    name_series = df["name"]
    if no_underline:
        # Kaomoji tags are emoticons whose underscore is part of the face,
        # so they are left as they are.
        name_series = name_series.map(
            lambda x: x.replace("_", " ") if x not in _KAOMOJIS else x
        )
    # Build the names from the (possibly rewritten) series; reading df["name"]
    # again here would throw away the no_underline replacement.
    tag_names = name_series.tolist()

    # Category codes used by the tag table: 9 = rating, 0 = general, 4 = character.
    rating_indexes = list(np.where(df["category"] == 9)[0])
    general_indexes = list(np.where(df["category"] == 0)[0])
    character_indexes = list(np.where(df["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes


def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
                  general_threshold: float = 0.35, character_threshold: float = 0.85,
                  drop_overlap: bool = False):
def _mcut_threshold(probs) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
     for Multi-label Classification. In 11th International Symposium, IDA 2012
     (pp. 172-183).
    """
    sorted_probs = probs[probs.argsort()[::-1]]
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh


def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    """
    Turn an image into the input array expected by the tagger model.

    The image is loaded through ``load_image`` (RGB mode, white background), pasted
    centered onto a white square canvas whose side is its longer edge, resized with
    bicubic resampling when that side differs from ``target_size``, converted to
    ``float32``, channel-reversed (RGB to BGR) and given a leading batch axis.

    :param image: The input image.
    :type image: ImageTyping
    :param target_size: Side length of the square model input.
    :type target_size: int
    :return: Array of shape ``(1, target_size, target_size, 3)`` with ``float32`` values.
    """
    rgb = load_image(image, force_background='white', mode='RGB')
    size = rgb.size
    side = max(size)

    # Center the picture on a white square canvas.
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    offset = ((side - size[0]) // 2, (side - size[1]) // 2)
    canvas.paste(rgb, offset)

    if side != target_size:
        canvas = canvas.resize((target_size, target_size), Image.BICUBIC)

    # Reverse the channel axis, then add the batch axis.
    bgr = np.asarray(canvas, dtype=np.float32)[:, :, ::-1]
    return np.expand_dims(bgr, axis=0)


def get_wd14_tags(
        image: ImageTyping,
        model_name: str = 'ConvNextV2',
        general_threshold: float = 0.35,
        general_mcut_enabled: bool = False,
        character_threshold: float = 0.85,
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
):
    """
    Overview:
        Tagging image by wd14 v2 model. Similar to
        `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .

    :param image: Image to tagging.
    :param model_name: Name of the mode, should be one of the \
        ``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``.
    :param general_threshold: Threshold for default tags, default is ``0.35``.
    :param character_threshold: Threshold for character tags, default is ``0.85``.
    :param drop_overlap: Drop overlap tags or not, default is ``False``.
    :return: Tagging results for levels, features and characters.
        Get tags for an image with wd14 taggers.
        Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The name of the model to use.
    :type model_name: str
    :param general_threshold: The threshold for general tags.
    :type general_threshold: float
    :param general_mcut_enabled: If True, applies MCut thresholding to general tags.
    :type general_mcut_enabled: bool
    :param character_threshold: The threshold for character tags.
    :type character_threshold: float
    :param character_mcut_enabled: If True, applies MCut thresholding to character tags.
    :type character_mcut_enabled: bool
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :type no_underline: bool
    :param drop_overlap: If True, drops overlapping tags.
    :type drop_overlap: bool
    :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
    :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]

    Example:
        Here are some images for example
@@ -124,38 +204,35 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
        >>> chars
        {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541}
    """
    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
    model = _get_wd14_model(model_name)
    _, height, width, _ = model.get_inputs()[0].shape

    # Load image, PIL RGB to OpenCV BGR
    image = load_image(image, mode='RGB', force_background='white')
    image = np.asarray(image)[:, :, ::-1]

    image = make_square(image, height)
    image = smart_resize(image, height)
    image = image.astype(np.float32)
    image = np.expand_dims(image, 0)
    _, target_size, _, _ = model.get_inputs()[0].shape
    image = _prepare_image_for_tagging(image, target_size)

    input_name = model.get_inputs()[0].name
    label_name = model.get_outputs()[0].name
    probs = model.run([label_name], {input_name: image})[0]
    preds = model.run([label_name], {input_name: image})[0]
    labels = list(zip(tag_names, preds[0].astype(float)))

    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels()
    labels = list(zip(tag_names, probs[0].astype(float).tolist()))

    # First 4 labels are actually ratings: pick one with argmax
    ratings_names = [labels[i] for i in rating_indexes]
    rating = dict(ratings_names)

    # Then we have general tags: pick anywhere prediction confidence > threshold
    general_names = [labels[i] for i in general_indexes]
    if general_mcut_enabled:
        general_probs = np.array([x[1] for x in general_names])
        general_threshold = _mcut_threshold(general_probs)

    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)

    # Everything else is characters: pick anywhere prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_names])
        character_threshold = _mcut_threshold(character_probs)
        character_threshold = max(0.15, character_threshold)

    character_res = [x for x in character_names if x[1] > character_threshold]
    character_res = dict(character_res)

+79 −29
Original line number Diff line number Diff line
@@ -34,22 +34,71 @@ class TestTaggingWd14:
            'sensitive': 0.0080718994140625,
            'questionable': 0.003170192241668701,
            'explicit': 0.984081506729126,
        }, abs=1e-3)
        }, abs=2e-2)
        assert tags == pytest.approx({
            '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507,
            'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438,
            'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896,
            'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697,
            'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648,
            'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317,
            'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223,
            'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056,
            'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686,
            'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
            'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727,
            'clitoris': 0.5310801267623901
        }, abs=1e-3)
        assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3)
            '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019,
            'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516,
            'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834,
            'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916,
            'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767,
            'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166,
            'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571,
            'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608,
            'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304,
            'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519,
            'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258,
            'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115,
            'clitoris': 0.5746099948883057
        }, abs=2e-2)
        assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2)

    def test_wd14_tags_sample_no_underline(self):
        """
        Check ``get_wd14_tags`` on the sample image with ``no_underline=True``.

        The expected general and character tag names use spaces instead of
        underscores; every score is compared with an absolute tolerance of 2e-2.
        """
        rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), no_underline=True)
        assert rating == pytest.approx({
            'general': 0.0020540356636047363,
            'sensitive': 0.0080718994140625,
            'questionable': 0.003170192241668701,
            'explicit': 0.984081506729126,
        }, abs=2e-2)
        assert tags == pytest.approx({
            '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long hair': 0.9451607465744019,
            'breasts': 0.9867608547210693, 'looking at viewer': 0.9200493693351746, 'blush': 0.8876285552978516,
            'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large breasts': 0.5059964656829834,
            'navel': 0.9681310653686523, 'hair between eyes': 0.5816333293914795, 'medium breasts': 0.36410677433013916,
            'very long hair': 0.811715304851532, 'closed mouth': 0.9338403940200806, 'nipples': 0.9715133905410767,
            'purple eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166,
            'red hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571,
            'pussy': 0.9876313805580139, 'spread legs': 0.9634276628494263, 'armpits': 0.9116500616073608,
            'stomach': 0.6858262419700623, 'arms up': 0.9398491978645325, 'completely nude': 0.907513439655304,
            'uncensored': 0.8703584671020508, 'pussy juice': 0.6459053754806519,
            'feet out of frame': 0.3921701908111572, 'on bed': 0.6049470901489258,
            'arms behind head': 0.4758358597755432, 'breasts apart': 0.38581883907318115,
            'clitoris': 0.5746099948883057
        }, abs=2e-2)
        assert chars == pytest.approx({'surtr (arknights)': 0.9942929744720459}, abs=2e-2)

    def test_wd14_tags_sample_mcut(self):
        """
        Check ``get_wd14_tags`` on the sample image with MCut thresholding enabled
        for both general and character tags.

        Every score is compared with an absolute tolerance of 2e-2.
        """
        rating, tags, chars = get_wd14_tags(
            get_testfile('nude_girl.png'),
            general_mcut_enabled=True,
            character_mcut_enabled=True,
        )
        assert rating == pytest.approx({
            'general': 0.0020540356636047363,
            'sensitive': 0.0080718994140625,
            'questionable': 0.003170192241668701,
            'explicit': 0.984081506729126,
        }, abs=2e-2)
        assert tags == pytest.approx({
            '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019,
            'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516,
            'navel': 0.9681310653686523, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806,
            'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'nude': 0.9568941593170166,
            'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571,
            'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608,
            'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508
        }, abs=2e-2)
        assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2)

    def test_wd14_tags_no_overlap(self):
        rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True)
@@ -59,18 +108,19 @@ class TestTaggingWd14:
            'sensitive': 0.0080718994140625,
            'questionable': 0.003170192241668701,
            'explicit': 0.984081506729126,
        }, abs=1e-3)
        }, abs=2e-2)
        assert tags == pytest.approx({
            '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698,
            'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605,
            'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948,
            'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633,
            'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317,
            'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658,
            'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423,
            'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345,
            'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
            'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727,
            'clitoris': 0.5310801267623901
        }, abs=1e-3)
        assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3)
            '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'looking_at_viewer': 0.9200493693351746,
            'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614,
            'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795,
            'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532,
            'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032,
            'collarbone': 0.573296308517456, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621,
            'horns': 0.973071277141571, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608,
            'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304,
            'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519,
            'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258,
            'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115,
            'clitoris': 0.5746099948883057
        }, abs=2e-2)
        assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2)