Loading imgutils/tagging/wd14.py +52 −13 Original line number Diff line number Diff line Loading @@ -7,7 +7,8 @@ Overview: project on Hugging Face. """ from typing import List, Tuple, Any from collections import defaultdict from typing import List, Tuple, Any, Optional, Mapping, Dict, Union import numpy as np import onnxruntime Loading @@ -19,7 +20,8 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache from ..generic.attachment import open_attachment, Attachment from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache, vnames SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -189,6 +191,7 @@ def _postprocess_embedding( no_underline: bool = False, drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, ): """ Post-process the embedding and prediction results. Loading @@ -214,6 +217,27 @@ def _postprocess_embedding( :param fmt: The format of the output. :return: The post-processed results. """ attachments = dict(attachments or {}) d_attachments: Dict[str, Tuple[Attachment, dict]] = {} for attach_name, attach_tpl in attachments.items(): if '/' in attach_name: raise ValueError(f'Invalid attachment register name, no \'/\' required - {attach_name!r}.') if len(attach_tpl) == 2: (attach_repo_id, attach_model_name), attach_kwargs = attach_tpl, {} elif len(attach_tpl) == 3: attach_repo_id, attach_model_name, attach_kwargs = attach_tpl else: raise ValueError(f'Invalid attachment tuple for {attach_name!r}, ' f'2 or 3 elements expected but {attach_tpl!r} found.') attachment = open_attachment(repo_id=attach_repo_id, model_name=attach_model_name) expected_encoder_model = f'wdtagger:{MODEL_NAMES[model_name]}' if attachment.encoder_model != expected_encoder_model: raise ValueError(f'Attachment encoder model not match, ' f'{expected_encoder_model!r} expected but {attachment.encoder_model!r} found ' f'for {attach_name!r}.') d_attachments[attach_name] = (attachment, attach_kwargs) assert len(pred.shape) == len(embedding.shape) == 1, \ f'Both pred and embeddings shapes should be 1-dim, ' \ f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.' Loading @@ -239,9 +263,7 @@ def _postprocess_embedding( character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { mapping_values = { 'rating': rating, 'general': general_res, 'character': character_res, Loading @@ -249,7 +271,19 @@ def _postprocess_embedding( 'embedding': embedding.astype(np.float32), 'prediction': pred.astype(np.float32), } ) d_attach_infers = defaultdict(list) for vname in vnames(fmt): if '/' in vname and vname.split('/', maxsplit=1)[0] in d_attachments: attach_name, attach_fmt_name = vname.split('/', maxsplit=1) d_attach_infers[attach_name].append(attach_fmt_name) for attach_name, attach_infer_names in d_attach_infers.items(): attachment, attach_kwargs = d_attachments[attach_name] attach_infer_values = attachment.predict(embedding=embedding, fmt=attach_infer_names, **attach_kwargs) attach_mapping_names = [f'{attach_name}/{name}' for name in attach_infer_names] mapping_values.update(dict(zip(attach_mapping_names, attach_infer_values))) return vreplace(fmt, mapping_values) def get_wd14_tags( Loading @@ -262,6 +296,7 @@ def get_wd14_tags( no_underline: bool = False, drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, ): """ Get tags for an image using WD14 taggers. Loading Loading @@ -356,6 +391,7 @@ def get_wd14_tags( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, attachments=attachments, ) Loading @@ -372,6 +408,7 @@ def convert_wd14_emb_to_prediction( no_underline: bool = False, drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, denormalize: bool = False, denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME, ): Loading Loading @@ -453,6 +490,7 @@ def convert_wd14_emb_to_prediction( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, attachments=attachments, ) else: return [ Loading @@ -467,6 +505,7 @@ def convert_wd14_emb_to_prediction( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, attachments=attachments, ) for pred_item, emb_item in zip(pred, emb) ] Loading test/tagging/test_wd14.py +91 −0 Original line number Diff line number Diff line import numpy as np import pytest from hbutils.testing import tmatrix from imgutils.generic.attachment import open_attachment from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model from test.testings import get_testfile Loading @@ -13,6 +15,7 @@ def _release_model_after_run(): finally: _get_wd14_model.cache_clear() _open_denormalize_model.cache_clear() open_attachment.cache_clear() @pytest.mark.unittest Loading Loading @@ -229,3 +232,91 @@ class TestTaggingWd14: assert rating == pytest.approx(expected_rating, abs=1e-2) assert general == pytest.approx(expected_general, abs=1e-2) assert character == pytest.approx(expected_character, abs=1e-2) @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('monochrome', '6125854(第 3 个复件).jpg'), ('monochrome', '5221834.jpg'), ('monochrome', '1951253.jpg'), ('monochrome', '4879658.jpg'), ('monochrome', '80750471_p3_master1200.jpg'), ('normal', '54566940_p0_master1200.jpg'), ('normal', '60817155_p18_master1200.jpg'), ('normal', '4945494.jpg'), ('normal', '4008375.jpg'), ('normal', '2416278.jpg'), ('normal', '842709.jpg') ], }, mode='matrix')) def test_get_wd14_tags_with_attachments(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/scores', 'monochrome/top'), attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')}, ) assert scores[type_] >= 0.5 assert top_label == type_ assert top_score >= 0.5 @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('monochrome', '6125854(第 3 个复件).jpg'), ('monochrome', '5221834.jpg'), ('monochrome', '1951253.jpg'), ('monochrome', '4879658.jpg'), ('monochrome', '80750471_p3_master1200.jpg'), ('normal', '54566940_p0_master1200.jpg'), ('normal', '60817155_p18_master1200.jpg'), ('normal', '4945494.jpg'), ('normal', '4008375.jpg'), ('normal', '2416278.jpg'), ('normal', '842709.jpg') ], }, mode='matrix')) def test_get_wd14_tags_with_attachments_extra_cfg(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/scores', 'monochrome/top'), attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1', {})}, ) assert scores[type_] >= 0.5 assert top_label == type_ assert top_score >= 0.5 @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('normal', '54566940_p0_master1200.jpg'), ], }, mode='matrix')) def test_get_wd14_tags_with_attachments_invalid_attachment_name(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) with pytest.raises(ValueError): scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/t/scores', 'monochrome/t/top'), attachments={'monochrome/t': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')}, ) _ = top_label, top_score @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('normal', '54566940_p0_master1200.jpg'), ], }, mode='matrix')) def test_get_wd14_tags_with_attachments_invalid_attachment_config(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) with pytest.raises(ValueError): scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/scores', 'monochrome/top'), attachments={'monochrome': ('deepghs/eattach_monochrome_experiments',)}, ) _ = top_label, top_score Loading
imgutils/tagging/wd14.py +52 −13 Original line number Diff line number Diff line Loading @@ -7,7 +7,8 @@ Overview: project on Hugging Face. """ from typing import List, Tuple, Any from collections import defaultdict from typing import List, Tuple, Any, Optional, Mapping, Dict, Union import numpy as np import onnxruntime Loading @@ -19,7 +20,8 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache from ..generic.attachment import open_attachment, Attachment from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache, vnames SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -189,6 +191,7 @@ def _postprocess_embedding( no_underline: bool = False, drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, ): """ Post-process the embedding and prediction results. Loading @@ -214,6 +217,27 @@ def _postprocess_embedding( :param fmt: The format of the output. :return: The post-processed results. """ attachments = dict(attachments or {}) d_attachments: Dict[str, Tuple[Attachment, dict]] = {} for attach_name, attach_tpl in attachments.items(): if '/' in attach_name: raise ValueError(f'Invalid attachment register name, no \'/\' required - {attach_name!r}.') if len(attach_tpl) == 2: (attach_repo_id, attach_model_name), attach_kwargs = attach_tpl, {} elif len(attach_tpl) == 3: attach_repo_id, attach_model_name, attach_kwargs = attach_tpl else: raise ValueError(f'Invalid attachment tuple for {attach_name!r}, ' f'2 or 3 elements expected but {attach_tpl!r} found.') attachment = open_attachment(repo_id=attach_repo_id, model_name=attach_model_name) expected_encoder_model = f'wdtagger:{MODEL_NAMES[model_name]}' if attachment.encoder_model != expected_encoder_model: raise ValueError(f'Attachment encoder model not match, ' f'{expected_encoder_model!r} expected but {attachment.encoder_model!r} found ' f'for {attach_name!r}.') d_attachments[attach_name] = (attachment, attach_kwargs) assert len(pred.shape) == len(embedding.shape) == 1, \ f'Both pred and embeddings shapes should be 1-dim, ' \ f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.' Loading @@ -239,9 +263,7 @@ def _postprocess_embedding( character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { mapping_values = { 'rating': rating, 'general': general_res, 'character': character_res, Loading @@ -249,7 +271,19 @@ def _postprocess_embedding( 'embedding': embedding.astype(np.float32), 'prediction': pred.astype(np.float32), } ) d_attach_infers = defaultdict(list) for vname in vnames(fmt): if '/' in vname and vname.split('/', maxsplit=1)[0] in d_attachments: attach_name, attach_fmt_name = vname.split('/', maxsplit=1) d_attach_infers[attach_name].append(attach_fmt_name) for attach_name, attach_infer_names in d_attach_infers.items(): attachment, attach_kwargs = d_attachments[attach_name] attach_infer_values = attachment.predict(embedding=embedding, fmt=attach_infer_names, **attach_kwargs) attach_mapping_names = [f'{attach_name}/{name}' for name in attach_infer_names] mapping_values.update(dict(zip(attach_mapping_names, attach_infer_values))) return vreplace(fmt, mapping_values) def get_wd14_tags( Loading @@ -262,6 +296,7 @@ def get_wd14_tags( no_underline: bool = False, drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, ): """ Get tags for an image using WD14 taggers. Loading Loading @@ -356,6 +391,7 @@ def get_wd14_tags( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, attachments=attachments, ) Loading @@ -372,6 +408,7 @@ def convert_wd14_emb_to_prediction( no_underline: bool = False, drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None, denormalize: bool = False, denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME, ): Loading Loading @@ -453,6 +490,7 @@ def convert_wd14_emb_to_prediction( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, attachments=attachments, ) else: return [ Loading @@ -467,6 +505,7 @@ def convert_wd14_emb_to_prediction( no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, attachments=attachments, ) for pred_item, emb_item in zip(pred, emb) ] Loading
test/tagging/test_wd14.py +91 −0 Original line number Diff line number Diff line import numpy as np import pytest from hbutils.testing import tmatrix from imgutils.generic.attachment import open_attachment from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model from test.testings import get_testfile Loading @@ -13,6 +15,7 @@ def _release_model_after_run(): finally: _get_wd14_model.cache_clear() _open_denormalize_model.cache_clear() open_attachment.cache_clear() @pytest.mark.unittest Loading Loading @@ -229,3 +232,91 @@ class TestTaggingWd14: assert rating == pytest.approx(expected_rating, abs=1e-2) assert general == pytest.approx(expected_general, abs=1e-2) assert character == pytest.approx(expected_character, abs=1e-2) @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('monochrome', '6125854(第 3 个复件).jpg'), ('monochrome', '5221834.jpg'), ('monochrome', '1951253.jpg'), ('monochrome', '4879658.jpg'), ('monochrome', '80750471_p3_master1200.jpg'), ('normal', '54566940_p0_master1200.jpg'), ('normal', '60817155_p18_master1200.jpg'), ('normal', '4945494.jpg'), ('normal', '4008375.jpg'), ('normal', '2416278.jpg'), ('normal', '842709.jpg') ], }, mode='matrix')) def test_get_wd14_tags_with_attachments(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/scores', 'monochrome/top'), attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')}, ) assert scores[type_] >= 0.5 assert top_label == type_ assert top_score >= 0.5 @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('monochrome', '6125854(第 3 个复件).jpg'), ('monochrome', '5221834.jpg'), ('monochrome', '1951253.jpg'), ('monochrome', '4879658.jpg'), ('monochrome', '80750471_p3_master1200.jpg'), ('normal', '54566940_p0_master1200.jpg'), ('normal', '60817155_p18_master1200.jpg'), ('normal', '4945494.jpg'), ('normal', '4008375.jpg'), ('normal', '2416278.jpg'), ('normal', '842709.jpg') ], }, mode='matrix')) def test_get_wd14_tags_with_attachments_extra_cfg(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/scores', 'monochrome/top'), attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1', {})}, ) assert scores[type_] >= 0.5 assert top_label == type_ assert top_score >= 0.5 @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('normal', '54566940_p0_master1200.jpg'), ], }, mode='matrix')) def test_get_wd14_tags_with_attachments_invalid_attachment_name(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) with pytest.raises(ValueError): scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/t/scores', 'monochrome/t/top'), attachments={'monochrome/t': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')}, ) _ = top_label, top_score @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): [ ('monochrome', '6130053.jpg'), ('normal', '54566940_p0_master1200.jpg'), ], }, mode='matrix')) def test_get_wd14_tags_with_attachments_invalid_attachment_config(self, type_, file): filename = get_testfile('dataset', 'monochrome_danbooru', type_, file) with pytest.raises(ValueError): scores, (top_label, top_score) = get_wd14_tags( filename, fmt=('monochrome/scores', 'monochrome/top'), attachments={'monochrome': ('deepghs/eattach_monochrome_experiments',)}, ) _ = top_label, top_score