Unverified Commit 1817ce63 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #92 from deepghs/dev/tagemb

dev(narugo): add prediction && use float as tagger output
parents 1b9e2c07 5fb64ae4
Loading
Loading
Loading
Loading
+15 −7
Original line number Diff line number Diff line
@@ -165,6 +165,16 @@ def get_wd14_tags(
    :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
    :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]

    .. note::
        About ``fmt`` argument, these are the available names:

        * ``rating``, a dict containing ratings and their confidences
        * ``general``, a dict containing general tags and their confidences
        * ``character``, a dict containing character tags and their confidences
        * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences
        * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization
        * ``prediction``, a 1-dim prediction result of image

    Example:
        Here are some images for example

@@ -202,16 +212,14 @@ def get_wd14_tags(
    preds, embeddings = model.run([label_name, emb_name], {input_name: image})
    labels = list(zip(tag_names, preds[0].astype(float)))

    ratings_names = [labels[i] for i in rating_indexes]
    rating = dict(ratings_names)
    rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}

    general_names = [labels[i] for i in general_indexes]
    if general_mcut_enabled:
        general_probs = np.array([x[1] for x in general_names])
        general_threshold = _mcut_threshold(general_probs)

    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)
    general_res = {x: v.item() for x, v in general_names if v > general_threshold}
    if drop_overlap:
        general_res = drop_overlap_tags(general_res)

@@ -221,8 +229,7 @@ def get_wd14_tags(
        character_threshold = _mcut_threshold(character_probs)
        character_threshold = max(0.15, character_threshold)

    character_res = [x for x in character_names if x[1] > character_threshold]
    character_res = dict(character_res)
    character_res = {x: v.item() for x, v in character_names if v > character_threshold}

    return vreplace(
        fmt,
@@ -231,6 +238,7 @@ def get_wd14_tags(
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embeddings[0],
            'embedding': embeddings[0].astype(np.float32),
            'prediction': preds[0].astype(np.float32),
        }
    )
+5 −0
Original line number Diff line number Diff line
@@ -21,11 +21,16 @@ class TestTaggingWd14:
        assert rating['general'] > 0.9
        assert tags['cat_girl'] >= 0.8
        assert not chars
        assert isinstance(rating['general'], float)
        assert isinstance(tags['cat_girl'], float)

        rating, tags, chars = get_wd14_tags(get_testfile('6125785.jpg'))
        assert 0.6 <= rating['general'] <= 0.8
        assert tags['1girl'] >= 0.95
        assert chars['hu_tao_(genshin_impact)'] >= 0.95
        assert isinstance(rating['general'], float)
        assert isinstance(tags['1girl'], float)
        assert isinstance(chars['hu_tao_(genshin_impact)'], float)

    def test_wd14_tags_sample(self):
        rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'))