Loading imgutils/tagging/wd14.py +9 −2 Original line number Diff line number Diff line Loading @@ -114,14 +114,21 @@ def _mcut_threshold(probs) -> float: return thresh def _has_alpha_channel(image: Image.Image) -> bool: return any(band in {'A', 'a', 'P'} for band in image.getbands()) def _prepare_image_for_tagging(image: ImageTyping, target_size: int): image = load_image(image, force_background='white', mode='RGB') image = load_image(image, force_background=None, mode=None) image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) if _has_alpha_channel(image): padded_image.paste(image, (pad_left, pad_top), mask=image) else: padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: Loading test/tagging/test_wd14.py +29 −0 Original line number Diff line number Diff line Loading @@ -130,3 +130,32 @@ class TestTaggingWd14: 'breasts_apart': 0.35740798711776733, 'clitoris': 0.44502270221710205 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9957615733146667}, abs=2e-2) def test_wd14_rgba(self): rating, tags, chars = get_wd14_tags(get_testfile('nian.png')) assert rating == pytest.approx({ 'general': 0.013875722885131836, 'sensitive': 0.9790834188461304, 'questionable': 0.0004328787326812744, 'explicit': 0.00010639429092407227, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.996912956237793, 'solo': 0.9690700769424438, 'long_hair': 0.9183608293533325, 'breasts': 0.5793432593345642, 'looking_at_viewer': 0.9029998779296875, 'smile': 0.7181373834609985, 'open_mouth': 0.5431916117668152, 'simple_background': 0.3519788384437561, 'long_sleeves': 0.7442969679832458, 'white_background': 0.6004813313484192, 'holding': 0.7325218319892883, 'navel': 0.9297535419464111, 'jewelry': 0.5435991287231445, 'standing': 0.8762419819831848, 'purple_eyes': 0.9269286394119263, 'tail': 0.8547350168228149, 'full_body': 0.9316157102584839, 'white_hair': 0.9207442402839661, 'braid': 0.37353646755218506, 'multicolored_hair': 0.6516135931015015, 'thighs': 0.451822429895401, ':d': 0.5130974054336548, 'red_hair': 0.5783762335777283, 'small_breasts': 0.3563075065612793, 'boots': 0.6243380308151245, 'open_clothes': 0.8822896480560303, 'horns': 0.965097188949585, 'shorts': 0.9586330056190491, 'shoes': 0.4847032427787781, 'socks': 0.47281092405319214, 'tongue': 0.9029147624969482, 'pointy_ears': 0.8633939623832703, 'belt': 0.4783763289451599, 'midriff': 0.9044876098632812, 'tongue_out': 0.9018264412879944, 'wide_sleeves': 0.7076666951179504, 'stomach': 0.891795814037323, 'streaked_hair': 0.6510426998138428, 'coat': 0.7965987324714661, 'crop_top': 0.6840215921401978, 'hand_on_own_hip': 0.5604047179222107, 'strapless': 0.950110137462616, 'short_shorts': 0.6481347680091858, 'bare_legs': 0.5356456637382507, 'white_footwear': 0.8399633169174194, 'transparent_background': 0.3643641471862793, ':p': 0.532076358795166, 'half_updo': 0.5155724883079529, 'open_coat': 0.8147380352020264, 'beads': 0.3977043032646179, 'white_shorts': 0.9007017612457275, 'white_coat': 0.8003122806549072, 'bandeau': 0.9671074151992798, 'tube_top': 0.9783295392990112, 'bead_bracelet': 0.3510066270828247, 'red_bandeau': 0.8741766214370728 }, abs=2e-2) assert chars == pytest.approx({'nian_(arknights)': 0.9968841671943665}, abs=2e-2) Loading
imgutils/tagging/wd14.py +9 −2 Original line number Diff line number Diff line Loading @@ -114,14 +114,21 @@ def _mcut_threshold(probs) -> float: return thresh def _has_alpha_channel(image: Image.Image) -> bool: return any(band in {'A', 'a', 'P'} for band in image.getbands()) def _prepare_image_for_tagging(image: ImageTyping, target_size: int): image = load_image(image, force_background='white', mode='RGB') image = load_image(image, force_background=None, mode=None) image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) if _has_alpha_channel(image): padded_image.paste(image, (pad_left, pad_top), mask=image) else: padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: Loading
test/tagging/test_wd14.py +29 −0 Original line number Diff line number Diff line Loading @@ -130,3 +130,32 @@ class TestTaggingWd14: 'breasts_apart': 0.35740798711776733, 'clitoris': 0.44502270221710205 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9957615733146667}, abs=2e-2) def test_wd14_rgba(self): rating, tags, chars = get_wd14_tags(get_testfile('nian.png')) assert rating == pytest.approx({ 'general': 0.013875722885131836, 'sensitive': 0.9790834188461304, 'questionable': 0.0004328787326812744, 'explicit': 0.00010639429092407227, }, abs=2e-2) assert tags == pytest.approx({ '1girl': 0.996912956237793, 'solo': 0.9690700769424438, 'long_hair': 0.9183608293533325, 'breasts': 0.5793432593345642, 'looking_at_viewer': 0.9029998779296875, 'smile': 0.7181373834609985, 'open_mouth': 0.5431916117668152, 'simple_background': 0.3519788384437561, 'long_sleeves': 0.7442969679832458, 'white_background': 0.6004813313484192, 'holding': 0.7325218319892883, 'navel': 0.9297535419464111, 'jewelry': 0.5435991287231445, 'standing': 0.8762419819831848, 'purple_eyes': 0.9269286394119263, 'tail': 0.8547350168228149, 'full_body': 0.9316157102584839, 'white_hair': 0.9207442402839661, 'braid': 0.37353646755218506, 'multicolored_hair': 0.6516135931015015, 'thighs': 0.451822429895401, ':d': 0.5130974054336548, 'red_hair': 0.5783762335777283, 'small_breasts': 0.3563075065612793, 'boots': 0.6243380308151245, 'open_clothes': 0.8822896480560303, 'horns': 0.965097188949585, 'shorts': 0.9586330056190491, 'shoes': 0.4847032427787781, 'socks': 0.47281092405319214, 'tongue': 0.9029147624969482, 'pointy_ears': 0.8633939623832703, 'belt': 0.4783763289451599, 'midriff': 0.9044876098632812, 'tongue_out': 0.9018264412879944, 'wide_sleeves': 0.7076666951179504, 'stomach': 0.891795814037323, 'streaked_hair': 0.6510426998138428, 'coat': 0.7965987324714661, 'crop_top': 0.6840215921401978, 'hand_on_own_hip': 0.5604047179222107, 'strapless': 0.950110137462616, 'short_shorts': 0.6481347680091858, 'bare_legs': 0.5356456637382507, 'white_footwear': 0.8399633169174194, 'transparent_background': 0.3643641471862793, ':p': 0.532076358795166, 'half_updo': 0.5155724883079529, 'open_coat': 0.8147380352020264, 'beads': 0.3977043032646179, 'white_shorts': 0.9007017612457275, 'white_coat': 0.8003122806549072, 'bandeau': 0.9671074151992798, 'tube_top': 0.9783295392990112, 'bead_bracelet': 0.3510066270828247, 'red_bandeau': 0.8741766214370728 }, abs=2e-2) assert chars == pytest.approx({'nian_(arknights)': 0.9968841671943665}, abs=2e-2)