Loading docs/source/api_doc/tagging/tagging_benchmark.plot.py +3 −0 Original line number Diff line number Diff line Loading @@ -59,6 +59,9 @@ if __name__ == '__main__': ('wd14-convnextv2', Wd14Benchmark("ConvNextV2")), ('wd14-vit', Wd14Benchmark("ViT")), ('wd14-moat', Wd14Benchmark("MOAT")), ('wd14-swinv2-v3', Wd14Benchmark("SwinV2_v3")), ('wd14-vit-v3', Wd14Benchmark("ViT_v3")), ('wd14-convnext-v3', Wd14Benchmark("ConvNext_v3")), ('mldanbooru', MLDanbooruBenchmark()), ], title='Benchmark for Tagging Models', Loading docs/source/api_doc/tagging/tagging_benchmark.plot.py.svg +884 −589 File changed.Preview size limit exceeded, changes collapsed. Show changes imgutils/tagging/wd14.py +144 −67 Original line number Diff line number Diff line Loading @@ -6,98 +6,178 @@ Overview: from functools import lru_cache from typing import List, Tuple import cv2 import huggingface_hub import numpy as np import onnxruntime import pandas as pd from PIL import Image from hbutils.testing.requires.version import VersionInfo from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model def make_square(img, target_size): old_size = img.shape[:2] desired_size = max(old_size) desired_size = max(desired_size, target_size) delta_w = desired_size - old_size[1] delta_h = desired_size - old_size[0] top, bottom = delta_h // 2, delta_h - (delta_h // 2) left, right = delta_w // 2, delta_w - (delta_w // 2) color = [255, 255, 255] # noinspection PyUnresolvedReferences new_im = cv2.copyMakeBorder( img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color ) return new_im def smart_resize(img, size): # Assumes the image has already gone through make_square if img.shape[0] > size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) elif img.shape[0] < size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) return img SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2" MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2" CONV_V3_MODEL_REPO = 'SmilingWolf/wd-convnext-tagger-v3' SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3' VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3' MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" _IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17' MODEL_NAMES = { "SwinV2": SWIN_MODEL_REPO, "ConvNext": CONV_MODEL_REPO, "ConvNextV2": CONV2_MODEL_REPO, "ViT": VIT_MODEL_REPO, "MOAT": MOAT_MODEL_REPO, "SwinV2_v3": SWIN_V3_MODEL_REPO, "ConvNext_v3": CONV_V3_MODEL_REPO, "ViT_v3": VIT_V3_MODEL_REPO, } def _load_wd14_model(model_repo: str, model_filename: str): return open_onnx_model(huggingface_hub.hf_hub_download(model_repo, model_filename)) def _version_support_check(model_name): if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.') # pragma: no cover _KAOMOJIS = [ "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||", ] @lru_cache() def _get_wd14_model(model_name): return _load_wd14_model(MODEL_NAMES[model_name], MODEL_FILENAME) """ Load an ONNX model from the Hugging Face Hub. :param model_name: The name of the model. :type model_name: str :return: The loaded ONNX model. :rtype: ONNXModel """ _version_support_check(model_name) return open_onnx_model(huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)) @lru_cache() def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]: path = huggingface_hub.hf_hub_download(CONV2_MODEL_REPO, LABEL_FILENAME) def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: """ Get labels for the WD14 model. :param model_name: The name of the model. :type model_name: str :param no_underline: If True, replaces underscores in tag names with spaces. :type no_underline: bool :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories. :rtype: Tuple[List[str], List[int], List[int], List[int]] """ path = huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME) df = pd.read_csv(path) name_series = df["name"] if no_underline: name_series = name_series.map( lambda x: x.replace("_", " ") if x not in _KAOMOJIS else x ) tag_names = name_series.tolist() tag_names = df["name"].tolist() rating_indexes = list(np.where(df["category"] == 9)[0]) general_indexes = list(np.where(df["category"] == 0)[0]) character_indexes = list(np.where(df["category"] == 4)[0]) return tag_names, rating_indexes, general_indexes, character_indexes def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", general_threshold: float = 0.35, character_threshold: float = 0.85, drop_overlap: bool = False): def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] t = difs.argmax() thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2 return thresh def _prepare_image_for_tagging(image: ImageTyping, target_size: int): image = load_image(image, force_background='white', mode='RGB') image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC) image_array = np.asarray(padded_image, dtype=np.float32) image_array = image_array[:, :, ::-1] return np.expand_dims(image_array, axis=0) def get_wd14_tags( image: ImageTyping, model_name: str = 'ConvNextV2', general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, ): """ Overview: Tagging image by wd14 v2 model. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . :param image: Image to tagging. :param model_name: Name of the mode, should be one of the \ ``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``. :param general_threshold: Threshold for default tags, default is ``0.35``. :param character_threshold: Threshold for character tags, default is ``0.85``. :param drop_overlap: Drop overlap tags or not, default is ``False``. :return: Tagging results for levels, features and characters. Get tags for an image with wd14 taggers. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . :param image: The input image. :type image: ImageTyping :param model_name: The name of the model to use. :type model_name: str :param general_threshold: The threshold for general tags. :type general_threshold: float :param general_mcut_enabled: If True, applies MCut thresholding to general tags. :type general_mcut_enabled: bool :param character_threshold: The threshold for character tags. :type character_threshold: float :param character_mcut_enabled: If True, applies MCut thresholding to character tags. :type character_mcut_enabled: bool :param no_underline: If True, replaces underscores in tag names with spaces. :type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] Example: Here are some images for example Loading @@ -124,38 +204,35 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, height, width, _ = model.get_inputs()[0].shape # Load image, PIL RGB to OpenCV BGR image = load_image(image, mode='RGB', force_background='white') image = np.asarray(image)[:, :, ::-1] image = make_square(image, height) image = smart_resize(image, height) image = image.astype(np.float32) image = np.expand_dims(image, 0) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) input_name = model.get_inputs()[0].name label_name = model.get_outputs()[0].name probs = model.run([label_name], {input_name: image})[0] preds = model.run([label_name], {input_name: image})[0] labels = list(zip(tag_names, preds[0].astype(float))) tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels() labels = list(zip(tag_names, probs[0].astype(float).tolist())) # First 4 labels are actually ratings: pick one with argmax ratings_names = [labels[i] for i in rating_indexes] rating = dict(ratings_names) # Then we have general tags: pick anywhere prediction confidence > threshold general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) if drop_overlap: general_res = drop_overlap_tags(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) Loading test/tagging/test_wd14.py +79 −29 Original line number Diff line number Diff line Loading @@ -34,22 +34,71 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_no_underline(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), no_underline=True) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking at viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair between eyes': 0.5816333293914795, 'medium breasts': 0.36410677433013916, 'very long hair': 0.811715304851532, 'closed mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms up': 0.9398491978645325, 'completely nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy juice': 0.6459053754806519, 'feet out of frame': 0.3921701908111572, 'on bed': 0.6049470901489258, 'arms behind head': 0.4758358597755432, 'breasts apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr (arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_mcut(self): rating, tags, chars = get_wd14_tags( get_testfile('nude_girl.png'), general_mcut_enabled=True, character_mcut_enabled=True, ) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'navel': 0.9681310653686523, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_no_overlap(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True) Loading @@ -59,18 +108,19 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) Loading
docs/source/api_doc/tagging/tagging_benchmark.plot.py +3 −0 Original line number Diff line number Diff line Loading @@ -59,6 +59,9 @@ if __name__ == '__main__': ('wd14-convnextv2', Wd14Benchmark("ConvNextV2")), ('wd14-vit', Wd14Benchmark("ViT")), ('wd14-moat', Wd14Benchmark("MOAT")), ('wd14-swinv2-v3', Wd14Benchmark("SwinV2_v3")), ('wd14-vit-v3', Wd14Benchmark("ViT_v3")), ('wd14-convnext-v3', Wd14Benchmark("ConvNext_v3")), ('mldanbooru', MLDanbooruBenchmark()), ], title='Benchmark for Tagging Models', Loading
docs/source/api_doc/tagging/tagging_benchmark.plot.py.svg +884 −589 File changed.Preview size limit exceeded, changes collapsed. Show changes
imgutils/tagging/wd14.py +144 −67 Original line number Diff line number Diff line Loading @@ -6,98 +6,178 @@ Overview: from functools import lru_cache from typing import List, Tuple import cv2 import huggingface_hub import numpy as np import onnxruntime import pandas as pd from PIL import Image from hbutils.testing.requires.version import VersionInfo from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model def make_square(img, target_size): old_size = img.shape[:2] desired_size = max(old_size) desired_size = max(desired_size, target_size) delta_w = desired_size - old_size[1] delta_h = desired_size - old_size[0] top, bottom = delta_h // 2, delta_h - (delta_h // 2) left, right = delta_w // 2, delta_w - (delta_w // 2) color = [255, 255, 255] # noinspection PyUnresolvedReferences new_im = cv2.copyMakeBorder( img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color ) return new_im def smart_resize(img, size): # Assumes the image has already gone through make_square if img.shape[0] > size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) elif img.shape[0] < size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) return img SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2" MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2" CONV_V3_MODEL_REPO = 'SmilingWolf/wd-convnext-tagger-v3' SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3' VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3' MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" _IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17' MODEL_NAMES = { "SwinV2": SWIN_MODEL_REPO, "ConvNext": CONV_MODEL_REPO, "ConvNextV2": CONV2_MODEL_REPO, "ViT": VIT_MODEL_REPO, "MOAT": MOAT_MODEL_REPO, "SwinV2_v3": SWIN_V3_MODEL_REPO, "ConvNext_v3": CONV_V3_MODEL_REPO, "ViT_v3": VIT_V3_MODEL_REPO, } def _load_wd14_model(model_repo: str, model_filename: str): return open_onnx_model(huggingface_hub.hf_hub_download(model_repo, model_filename)) def _version_support_check(model_name): if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.') # pragma: no cover _KAOMOJIS = [ "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||", ] @lru_cache() def _get_wd14_model(model_name): return _load_wd14_model(MODEL_NAMES[model_name], MODEL_FILENAME) """ Load an ONNX model from the Hugging Face Hub. :param model_name: The name of the model. :type model_name: str :return: The loaded ONNX model. :rtype: ONNXModel """ _version_support_check(model_name) return open_onnx_model(huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)) @lru_cache() def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]: path = huggingface_hub.hf_hub_download(CONV2_MODEL_REPO, LABEL_FILENAME) def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: """ Get labels for the WD14 model. :param model_name: The name of the model. :type model_name: str :param no_underline: If True, replaces underscores in tag names with spaces. :type no_underline: bool :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories. :rtype: Tuple[List[str], List[int], List[int], List[int]] """ path = huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME) df = pd.read_csv(path) name_series = df["name"] if no_underline: name_series = name_series.map( lambda x: x.replace("_", " ") if x not in _KAOMOJIS else x ) tag_names = name_series.tolist() tag_names = df["name"].tolist() rating_indexes = list(np.where(df["category"] == 9)[0]) general_indexes = list(np.where(df["category"] == 0)[0]) character_indexes = list(np.where(df["category"] == 4)[0]) return tag_names, rating_indexes, general_indexes, character_indexes def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", general_threshold: float = 0.35, character_threshold: float = 0.85, drop_overlap: bool = False): def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] t = difs.argmax() thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2 return thresh def _prepare_image_for_tagging(image: ImageTyping, target_size: int): image = load_image(image, force_background='white', mode='RGB') image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC) image_array = np.asarray(padded_image, dtype=np.float32) image_array = image_array[:, :, ::-1] return np.expand_dims(image_array, axis=0) def get_wd14_tags( image: ImageTyping, model_name: str = 'ConvNextV2', general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, ): """ Overview: Tagging image by wd14 v2 model. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . :param image: Image to tagging. :param model_name: Name of the mode, should be one of the \ ``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``. :param general_threshold: Threshold for default tags, default is ``0.35``. :param character_threshold: Threshold for character tags, default is ``0.85``. :param drop_overlap: Drop overlap tags or not, default is ``False``. :return: Tagging results for levels, features and characters. Get tags for an image with wd14 taggers. Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . :param image: The input image. :type image: ImageTyping :param model_name: The name of the model to use. :type model_name: str :param general_threshold: The threshold for general tags. :type general_threshold: float :param general_mcut_enabled: If True, applies MCut thresholding to general tags. :type general_mcut_enabled: bool :param character_threshold: The threshold for character tags. :type character_threshold: float :param character_mcut_enabled: If True, applies MCut thresholding to character tags. :type character_mcut_enabled: bool :param no_underline: If True, replaces underscores in tag names with spaces. :type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] Example: Here are some images for example Loading @@ -124,38 +204,35 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, height, width, _ = model.get_inputs()[0].shape # Load image, PIL RGB to OpenCV BGR image = load_image(image, mode='RGB', force_background='white') image = np.asarray(image)[:, :, ::-1] image = make_square(image, height) image = smart_resize(image, height) image = image.astype(np.float32) image = np.expand_dims(image, 0) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) input_name = model.get_inputs()[0].name label_name = model.get_outputs()[0].name probs = model.run([label_name], {input_name: image})[0] preds = model.run([label_name], {input_name: image})[0] labels = list(zip(tag_names, preds[0].astype(float))) tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels() labels = list(zip(tag_names, probs[0].astype(float).tolist())) # First 4 labels are actually ratings: pick one with argmax ratings_names = [labels[i] for i in rating_indexes] rating = dict(ratings_names) # Then we have general tags: pick anywhere prediction confidence > threshold general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) if drop_overlap: general_res = drop_overlap_tags(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) Loading
test/tagging/test_wd14.py +79 −29 Original line number Diff line number Diff line Loading @@ -34,22 +34,71 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_no_underline(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), no_underline=True) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking at viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair between eyes': 0.5816333293914795, 'medium breasts': 0.36410677433013916, 'very long hair': 0.811715304851532, 'closed mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms up': 0.9398491978645325, 'completely nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy juice': 0.6459053754806519, 'feet out of frame': 0.3921701908111572, 'on bed': 0.6049470901489258, 'arms behind head': 0.4758358597755432, 'breasts apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr (arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_mcut(self): rating, tags, chars = get_wd14_tags( get_testfile('nude_girl.png'), general_mcut_enabled=True, character_mcut_enabled=True, ) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'navel': 0.9681310653686523, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_no_overlap(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True) Loading @@ -59,18 +108,19 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2)