Loading imgutils/tagging/wd14.py +94 −54 Original line number Diff line number Diff line Loading @@ -6,50 +6,23 @@ Overview: from functools import lru_cache from typing import List, Tuple import cv2 import huggingface_hub import numpy as np import pandas as pd from PIL import Image from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model def make_square(img, target_size): old_size = img.shape[:2] desired_size = max(old_size) desired_size = max(desired_size, target_size) delta_w = desired_size - old_size[1] delta_h = desired_size - old_size[0] top, bottom = delta_h // 2, delta_h - (delta_h // 2) left, right = delta_w // 2, delta_w - (delta_w // 2) color = [255, 255, 255] # noinspection PyUnresolvedReferences new_im = cv2.copyMakeBorder( img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color ) return new_im def smart_resize(img, size): # Assumes the image has already gone through make_square if img.shape[0] > size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) elif img.shape[0] < size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) return img SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2" MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2" CONV_V3_MODEL_REPO = 'SmilingWolf/wd-convnext-tagger-v3' SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3' VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3' MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" Loading @@ -59,8 +32,34 @@ MODEL_NAMES = { "ConvNextV2": CONV2_MODEL_REPO, "ViT": VIT_MODEL_REPO, "MOAT": MOAT_MODEL_REPO, "SwinV2_v3": SWIN_V3_MODEL_REPO, "ConvNext_v3": CONV_V3_MODEL_REPO, "ViT_v3": VIT_V3_MODEL_REPO, } _KAOMOJIS = [ "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||", ] def _load_wd14_model(model_repo: str, model_filename: str): return open_onnx_model(huggingface_hub.hf_hub_download(model_repo, model_filename)) Loading @@ -72,20 +71,64 @@ def _get_wd14_model(model_name): @lru_cache() def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]: path = huggingface_hub.hf_hub_download(CONV2_MODEL_REPO, LABEL_FILENAME) def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: path = huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME) df = pd.read_csv(path) name_series = df["name"] if no_underline: name_series = name_series.map( lambda x: x.replace("_", " ") if x not in _KAOMOJIS else x ) tag_names = name_series.tolist() tag_names = df["name"].tolist() rating_indexes = list(np.where(df["category"] == 9)[0]) general_indexes = list(np.where(df["category"] == 0)[0]) character_indexes = list(np.where(df["category"] == 4)[0]) return tag_names, rating_indexes, general_indexes, character_indexes def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", general_threshold: float = 0.35, character_threshold: float = 0.85, drop_overlap: bool = False): def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] t = difs.argmax() thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2 return thresh def _prepare_image_for_tagging(image: ImageTyping, target_size: int): image = load_image(image, force_background='white', mode='RGB') image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC) image_array = np.asarray(padded_image, dtype=np.float32) image_array = image_array[:, :, ::-1] return np.expand_dims(image_array, axis=0) def get_wd14_tags( image: ImageTyping, model_name: str = 'ConvNextV2', general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, ): """ Overview: Tagging image by wd14 v2 model. Similar to Loading Loading @@ -124,38 +167,35 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, height, width, _ = model.get_inputs()[0].shape # Load image, PIL RGB to OpenCV BGR image = load_image(image, mode='RGB', force_background='white') image = np.asarray(image)[:, :, ::-1] image = make_square(image, height) image = smart_resize(image, height) image = image.astype(np.float32) image = np.expand_dims(image, 0) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) input_name = model.get_inputs()[0].name label_name = model.get_outputs()[0].name probs = model.run([label_name], {input_name: image})[0] tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels() labels = list(zip(tag_names, probs[0].astype(float).tolist())) preds = model.run([label_name], {input_name: image})[0] labels = list(zip(tag_names, preds[0].astype(float))) # First 4 labels are actually ratings: pick one with argmax ratings_names = [labels[i] for i in rating_indexes] rating = dict(ratings_names) # Then we have general tags: pick anywhere prediction confidence > threshold general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) if drop_overlap: general_res = drop_overlap_tags(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) Loading test/tagging/test_wd14.py +79 −29 Original line number Diff line number Diff line Loading @@ -34,22 +34,71 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_no_underline(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), no_underline=True) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking at viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair between eyes': 0.5816333293914795, 'medium breasts': 0.36410677433013916, 'very long hair': 0.811715304851532, 'closed mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms up': 0.9398491978645325, 'completely nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy juice': 0.6459053754806519, 'feet out of frame': 0.3921701908111572, 'on bed': 0.6049470901489258, 'arms behind head': 0.4758358597755432, 'breasts apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr (arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_mcut(self): rating, tags, chars = get_wd14_tags( get_testfile('nude_girl.png'), general_mcut_enabled=True, character_mcut_enabled=True, ) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'navel': 0.9681310653686523, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_no_overlap(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True) Loading @@ -59,18 +108,19 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) Loading
imgutils/tagging/wd14.py +94 −54 Original line number Diff line number Diff line Loading @@ -6,50 +6,23 @@ Overview: from functools import lru_cache from typing import List, Tuple import cv2 import huggingface_hub import numpy as np import pandas as pd from PIL import Image from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model def make_square(img, target_size): old_size = img.shape[:2] desired_size = max(old_size) desired_size = max(desired_size, target_size) delta_w = desired_size - old_size[1] delta_h = desired_size - old_size[0] top, bottom = delta_h // 2, delta_h - (delta_h // 2) left, right = delta_w // 2, delta_w - (delta_w // 2) color = [255, 255, 255] # noinspection PyUnresolvedReferences new_im = cv2.copyMakeBorder( img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color ) return new_im def smart_resize(img, size): # Assumes the image has already gone through make_square if img.shape[0] > size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA) elif img.shape[0] < size: # noinspection PyUnresolvedReferences img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC) return img SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2" MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2" CONV_V3_MODEL_REPO = 'SmilingWolf/wd-convnext-tagger-v3' SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3' VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3' MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" Loading @@ -59,8 +32,34 @@ MODEL_NAMES = { "ConvNextV2": CONV2_MODEL_REPO, "ViT": VIT_MODEL_REPO, "MOAT": MOAT_MODEL_REPO, "SwinV2_v3": SWIN_V3_MODEL_REPO, "ConvNext_v3": CONV_V3_MODEL_REPO, "ViT_v3": VIT_V3_MODEL_REPO, } _KAOMOJIS = [ "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<", "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||", ] def _load_wd14_model(model_repo: str, model_filename: str): return open_onnx_model(huggingface_hub.hf_hub_download(model_repo, model_filename)) Loading @@ -72,20 +71,64 @@ def _get_wd14_model(model_name): @lru_cache() def _get_wd14_labels() -> Tuple[List[str], List[int], List[int], List[int]]: path = huggingface_hub.hf_hub_download(CONV2_MODEL_REPO, LABEL_FILENAME) def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: path = huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME) df = pd.read_csv(path) name_series = df["name"] if no_underline: name_series = name_series.map( lambda x: x.replace("_", " ") if x not in _KAOMOJIS else x ) tag_names = name_series.tolist() tag_names = df["name"].tolist() rating_indexes = list(np.where(df["category"] == 9)[0]) general_indexes = list(np.where(df["category"] == 0)[0]) character_indexes = list(np.where(df["category"] == 4)[0]) return tag_names, rating_indexes, general_indexes, character_indexes def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", general_threshold: float = 0.35, character_threshold: float = 0.85, drop_overlap: bool = False): def _mcut_threshold(probs) -> float: """ Maximum Cut Thresholding (MCut) Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy for Multi-label Classification. In 11th International Symposium, IDA 2012 (pp. 172-183). """ sorted_probs = probs[probs.argsort()[::-1]] difs = sorted_probs[:-1] - sorted_probs[1:] t = difs.argmax() thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2 return thresh def _prepare_image_for_tagging(image: ImageTyping, target_size: int): image = load_image(image, force_background='white', mode='RGB') image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC) image_array = np.asarray(padded_image, dtype=np.float32) image_array = image_array[:, :, ::-1] return np.expand_dims(image_array, axis=0) def get_wd14_tags( image: ImageTyping, model_name: str = 'ConvNextV2', general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, ): """ Overview: Tagging image by wd14 v2 model. Similar to Loading Loading @@ -124,38 +167,35 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2", >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, height, width, _ = model.get_inputs()[0].shape # Load image, PIL RGB to OpenCV BGR image = load_image(image, mode='RGB', force_background='white') image = np.asarray(image)[:, :, ::-1] image = make_square(image, height) image = smart_resize(image, height) image = image.astype(np.float32) image = np.expand_dims(image, 0) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) input_name = model.get_inputs()[0].name label_name = model.get_outputs()[0].name probs = model.run([label_name], {input_name: image})[0] tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels() labels = list(zip(tag_names, probs[0].astype(float).tolist())) preds = model.run([label_name], {input_name: image})[0] labels = list(zip(tag_names, preds[0].astype(float))) # First 4 labels are actually ratings: pick one with argmax ratings_names = [labels[i] for i in rating_indexes] rating = dict(ratings_names) # Then we have general tags: pick anywhere prediction confidence > threshold general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) if drop_overlap: general_res = drop_overlap_tags(general_res) # Everything else is characters: pick anywhere prediction confidence > threshold character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) Loading
test/tagging/test_wd14.py +79 −29 Original line number Diff line number Diff line Loading @@ -34,22 +34,71 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507, 'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_no_underline(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), no_underline=True) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking at viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair between eyes': 0.5816333293914795, 'medium breasts': 0.36410677433013916, 'very long hair': 0.811715304851532, 'closed mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'nude': 0.9568941593170166, 'red hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms up': 0.9398491978645325, 'completely nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy juice': 0.6459053754806519, 'feet out of frame': 0.3921701908111572, 'on bed': 0.6049470901489258, 'arms behind head': 0.4758358597755432, 'breasts apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr (arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_sample_mcut(self): rating, tags, chars = get_wd14_tags( get_testfile('nude_girl.png'), general_mcut_enabled=True, character_mcut_enabled=True, ) assert rating == pytest.approx({ 'general': 0.0020540356636047363, 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'long_hair': 0.9451607465744019, 'breasts': 0.9867608547210693, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'navel': 0.9681310653686523, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'nude': 0.9568941593170166, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'pussy': 0.9876313805580139, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2) def test_wd14_tags_no_overlap(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'), drop_overlap=True) Loading @@ -59,18 +108,19 @@ class TestTaggingWd14: 'sensitive': 0.0080718994140625, 'questionable': 0.003170192241668701, 'explicit': 0.984081506729126, }, abs=1e-3) }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438, 'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896, 'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697, 'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648, 'collarbone': 0.588348925113678, 'red_hair': 0.9200156331062317, 'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056, 'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686, 'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291, 'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901 }, abs=1e-3) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=1e-3) '1girl': 0.998561441898346, 'solo': 0.9918843507766724, 'looking_at_viewer': 0.9200493693351746, 'blush': 0.8876285552978516, 'smile': 0.5031097531318665, 'bangs': 0.4979058504104614, 'large_breasts': 0.5059964656829834, 'navel': 0.9681310653686523, 'hair_between_eyes': 0.5816333293914795, 'medium_breasts': 0.36410677433013916, 'very_long_hair': 0.811715304851532, 'closed_mouth': 0.9338403940200806, 'nipples': 0.9715133905410767, 'purple_eyes': 0.9681202173233032, 'collarbone': 0.573296308517456, 'red_hair': 0.9242303967475891, 'sweat': 0.8757796287536621, 'horns': 0.973071277141571, 'spread_legs': 0.9634276628494263, 'armpits': 0.9116500616073608, 'stomach': 0.6858262419700623, 'arms_up': 0.9398491978645325, 'completely_nude': 0.907513439655304, 'uncensored': 0.8703584671020508, 'pussy_juice': 0.6459053754806519, 'feet_out_of_frame': 0.3921701908111572, 'on_bed': 0.6049470901489258, 'arms_behind_head': 0.4758358597755432, 'breasts_apart': 0.38581883907318115, 'clitoris': 0.5746099948883057 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9942929744720459}, abs=2e-2)