Commit 27ea19e7 authored by narugo1992's avatar narugo1992
Browse files

Merge branch 'main' into dev/head

parents 593d20c7 74d4a2d8
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -114,14 +114,21 @@ def _mcut_threshold(probs) -> float:
    return thresh


def _has_alpha_channel(image: Image.Image) -> bool:
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    image = load_image(image, force_background='white', mode='RGB')
    image = load_image(image, force_background=None, mode=None)
    image_shape = image.size
    max_dim = max(image_shape)
    pad_left = (max_dim - image_shape[0]) // 2
    pad_top = (max_dim - image_shape[1]) // 2

    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    if _has_alpha_channel(image):
        padded_image.paste(image, (pad_left, pad_top), mask=image)
    else:
        padded_image.paste(image, (pad_left, pad_top))

    if max_dim != target_size:
+29 −0
Original line number Diff line number Diff line
@@ -130,3 +130,32 @@ class TestTaggingWd14:
            'breasts_apart': 0.35740798711776733, 'clitoris': 0.44502270221710205
        }, abs=2e-2)
        assert chars == pytest.approx({'surtr_(arknights)': 0.9957615733146667}, abs=2e-2)

    def test_wd14_rgba(self):
        rating, tags, chars = get_wd14_tags(get_testfile('nian.png'))
        assert rating == pytest.approx({
            'general': 0.013875722885131836, 'sensitive': 0.9790834188461304,
            'questionable': 0.0004328787326812744, 'explicit': 0.00010639429092407227,
        }, abs=2e-2)
        assert tags == pytest.approx({
            '1girl': 0.996912956237793, 'solo': 0.9690700769424438, 'long_hair': 0.9183608293533325,
            'breasts': 0.5793432593345642, 'looking_at_viewer': 0.9029998779296875, 'smile': 0.7181373834609985,
            'open_mouth': 0.5431916117668152, 'simple_background': 0.3519788384437561,
            'long_sleeves': 0.7442969679832458, 'white_background': 0.6004813313484192, 'holding': 0.7325218319892883,
            'navel': 0.9297535419464111, 'jewelry': 0.5435991287231445, 'standing': 0.8762419819831848,
            'purple_eyes': 0.9269286394119263, 'tail': 0.8547350168228149, 'full_body': 0.9316157102584839,
            'white_hair': 0.9207442402839661, 'braid': 0.37353646755218506, 'multicolored_hair': 0.6516135931015015,
            'thighs': 0.451822429895401, ':d': 0.5130974054336548, 'red_hair': 0.5783762335777283,
            'small_breasts': 0.3563075065612793, 'boots': 0.6243380308151245, 'open_clothes': 0.8822896480560303,
            'horns': 0.965097188949585, 'shorts': 0.9586330056190491, 'shoes': 0.4847032427787781,
            'socks': 0.47281092405319214, 'tongue': 0.9029147624969482, 'pointy_ears': 0.8633939623832703,
            'belt': 0.4783763289451599, 'midriff': 0.9044876098632812, 'tongue_out': 0.9018264412879944,
            'wide_sleeves': 0.7076666951179504, 'stomach': 0.891795814037323, 'streaked_hair': 0.6510426998138428,
            'coat': 0.7965987324714661, 'crop_top': 0.6840215921401978, 'hand_on_own_hip': 0.5604047179222107,
            'strapless': 0.950110137462616, 'short_shorts': 0.6481347680091858, 'bare_legs': 0.5356456637382507,
            'white_footwear': 0.8399633169174194, 'transparent_background': 0.3643641471862793, ':p': 0.532076358795166,
            'half_updo': 0.5155724883079529, 'open_coat': 0.8147380352020264, 'beads': 0.3977043032646179,
            'white_shorts': 0.9007017612457275, 'white_coat': 0.8003122806549072, 'bandeau': 0.9671074151992798,
            'tube_top': 0.9783295392990112, 'bead_bracelet': 0.3510066270828247, 'red_bandeau': 0.8741766214370728
        }, abs=2e-2)
        assert chars == pytest.approx({'nian_(arknights)': 0.9968841671943665}, abs=2e-2)