Loading imgutils/tagging/wd14.py +15 −7 Original line number Diff line number Diff line Loading @@ -165,6 +165,16 @@ def get_wd14_tags( :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] .. note:: About ``fmt`` argument, these are the available names: * ``rating``, a dict containing ratings and their confidences * ``general``, a dict containing general tags and their confidences * ``character``, a dict containing character tags and their confidences * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization * ``prediction``, a 1-dim prediction result of image Example: Here are some images for example Loading Loading @@ -202,16 +212,14 @@ def get_wd14_tags( preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) ratings_names = [labels[i] for i in rating_indexes] rating = dict(ratings_names) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) Loading @@ -221,8 +229,7 @@ def get_wd14_tags( character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, Loading @@ -231,6 +238,7 @@ def get_wd14_tags( 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings[0], 'embedding': embeddings[0].astype(np.float32), 'prediction': preds[0].astype(np.float32), } ) test/tagging/test_wd14.py +5 −0 Original line number Diff line number Diff line Loading @@ -21,11 +21,16 @@ class TestTaggingWd14: assert rating['general'] > 0.9 assert tags['cat_girl'] >= 0.8 assert not chars assert isinstance(rating['general'], float) assert isinstance(tags['cat_girl'], float) rating, tags, chars = get_wd14_tags(get_testfile('6125785.jpg')) assert 0.6 <= rating['general'] <= 0.8 assert tags['1girl'] >= 0.95 assert chars['hu_tao_(genshin_impact)'] >= 0.95 assert isinstance(rating['general'], float) assert isinstance(tags['1girl'], float) assert isinstance(chars['hu_tao_(genshin_impact)'], float) def test_wd14_tags_sample(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png')) Loading Loading
imgutils/tagging/wd14.py +15 −7 Original line number Diff line number Diff line Loading @@ -165,6 +165,16 @@ def get_wd14_tags( :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] .. note:: About ``fmt`` argument, these are the available names: * ``rating``, a dict containing ratings and their confidences * ``general``, a dict containing general tags and their confidences * ``character``, a dict containing character tags and their confidences * ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences * ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization * ``prediction``, a 1-dim prediction result of image Example: Here are some images for example Loading Loading @@ -202,16 +212,14 @@ def get_wd14_tags( preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) ratings_names = [labels[i] for i in rating_indexes] rating = dict(ratings_names) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) Loading @@ -221,8 +229,7 @@ def get_wd14_tags( character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, Loading @@ -231,6 +238,7 @@ def get_wd14_tags( 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings[0], 'embedding': embeddings[0].astype(np.float32), 'prediction': preds[0].astype(np.float32), } )
test/tagging/test_wd14.py +5 −0 Original line number Diff line number Diff line Loading @@ -21,11 +21,16 @@ class TestTaggingWd14: assert rating['general'] > 0.9 assert tags['cat_girl'] >= 0.8 assert not chars assert isinstance(rating['general'], float) assert isinstance(tags['cat_girl'], float) rating, tags, chars = get_wd14_tags(get_testfile('6125785.jpg')) assert 0.6 <= rating['general'] <= 0.8 assert tags['1girl'] >= 0.95 assert chars['hu_tao_(genshin_impact)'] >= 0.95 assert isinstance(rating['general'], float) assert isinstance(tags['1girl'], float) assert isinstance(chars['hu_tao_(genshin_impact)'], float) def test_wd14_tags_sample(self): rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png')) Loading