Loading imgutils/tagging/wd14.py +98 −34 Original line number Diff line number Diff line Loading @@ -4,7 +4,7 @@ Overview: `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . """ from functools import lru_cache from typing import List, Tuple, Dict from typing import List, Tuple import numpy as np import onnxruntime Loading @@ -15,8 +15,8 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping, has_alpha_channel from ..utils import open_onnx_model, vreplace from ..data import load_image, ImageTyping from ..utils import open_onnx_model, vreplace, sigmoid SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -75,6 +75,15 @@ def _get_wd14_model(model_name): )) @lru_cache() def _get_wd14_weights(model_name): _version_support_check(model_name) return np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', filename=f'{MODEL_NAMES[model_name]}/inv.npz', )) @lru_cache() def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: """ Loading Loading @@ -135,6 +144,52 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int): return np.expand_dims(image_array, axis=0) def _postprocess_embedding( pred, embedding, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) labels = list(zip(tag_names, pred.astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if 
general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embedding.astype(np.float32), 'prediction': pred.astype(np.float32), } ) def get_wd14_tags( image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, Loading Loading @@ -170,8 +225,8 @@ def get_wd14_tags( :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. :type fmt: Any :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] :return: Prediction result based on the provided ``fmt``. In the default case, it should be a tuple of ``rating``, ``general`` and ``character``. .. 
note:: About ``fmt`` argument, these are the available names: Loading Loading @@ -208,7 +263,7 @@ def get_wd14_tags( >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) Loading @@ -218,35 +273,44 @@ def get_wd14_tags( label_name = model.get_outputs()[0].name emb_name = model.get_outputs()[1].name preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) return _postprocess_embedding( pred=preds[0], embedding=embeddings[0], model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, ) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings[0].astype(np.float32), 'prediction': 
preds[0].astype(np.float32), } def convert_wd14_emb_to_prediction( emb: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): z_weights = _get_wd14_weights(model_name) weights, bias = z_weights['weights'], z_weights['bias'] pred = sigmoid(emb @ weights + bias) return _postprocess_embedding( pred=pred, embedding=emb, model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, ) imgutils/utils/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ Overview: """ from .area import * from .format import * from .func import * from .onnxruntime import * from .storage import * from .tqdm_ import * imgutils/utils/func.py 0 → 100644 +7 −0 Original line number Diff line number Diff line import numpy as np __all__ = ['sigmoid'] def sigmoid(x): return 1 / (1 + np.exp(-x)) Loading
imgutils/tagging/wd14.py +98 −34 Original line number Diff line number Diff line Loading @@ -4,7 +4,7 @@ Overview: `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . """ from functools import lru_cache from typing import List, Tuple, Dict from typing import List, Tuple import numpy as np import onnxruntime Loading @@ -15,8 +15,8 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping, has_alpha_channel from ..utils import open_onnx_model, vreplace from ..data import load_image, ImageTyping from ..utils import open_onnx_model, vreplace, sigmoid SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -75,6 +75,15 @@ def _get_wd14_model(model_name): )) @lru_cache() def _get_wd14_weights(model_name): _version_support_check(model_name) return np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', filename=f'{MODEL_NAMES[model_name]}/inv.npz', )) @lru_cache() def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]: """ Loading Loading @@ -135,6 +144,52 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int): return np.expand_dims(image_array, axis=0) def _postprocess_embedding( pred, embedding, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) labels = list(zip(tag_names, pred.astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if 
general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embedding.astype(np.float32), 'prediction': pred.astype(np.float32), } ) def get_wd14_tags( image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, Loading Loading @@ -170,8 +225,8 @@ def get_wd14_tags( :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. :type fmt: Any :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] :return: Prediction result based on the provided ``fmt``. In the default case, it should be a tuple of ``rating``, ``general`` and ``character``. .. 
note:: About ``fmt`` argument, these are the available names: Loading Loading @@ -208,7 +263,7 @@ def get_wd14_tags( >>> chars {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541} """ tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline) model = _get_wd14_model(model_name) _, target_size, _, _ = model.get_inputs()[0].shape image = _prepare_image_for_tagging(image, target_size) Loading @@ -218,35 +273,44 @@ def get_wd14_tags( label_name = model.get_outputs()[0].name emb_name = model.get_outputs()[1].name preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes} general_names = [labels[i] for i in general_indexes] if general_mcut_enabled: general_probs = np.array([x[1] for x in general_names]) general_threshold = _mcut_threshold(general_probs) general_res = {x: v.item() for x, v in general_names if v > general_threshold} if drop_overlap: general_res = drop_overlap_tags(general_res) character_names = [labels[i] for i in character_indexes] if character_mcut_enabled: character_probs = np.array([x[1] for x in character_names]) character_threshold = _mcut_threshold(character_probs) character_threshold = max(0.15, character_threshold) return _postprocess_embedding( pred=preds[0], embedding=embeddings[0], model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, ) character_res = {x: v.item() for x, v in character_names if v > character_threshold} return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings[0].astype(np.float32), 'prediction': 
preds[0].astype(np.float32), } def convert_wd14_emb_to_prediction( emb: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME, general_threshold: float = 0.35, general_mcut_enabled: bool = False, character_threshold: float = 0.85, character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): z_weights = _get_wd14_weights(model_name) weights, bias = z_weights['weights'], z_weights['bias'] pred = sigmoid(emb @ weights + bias) return _postprocess_embedding( pred=pred, embedding=emb, model_name=model_name, general_threshold=general_threshold, general_mcut_enabled=general_mcut_enabled, character_threshold=character_threshold, character_mcut_enabled=character_mcut_enabled, no_underline=no_underline, drop_overlap=drop_overlap, fmt=fmt, )
imgutils/utils/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ Overview: """ from .area import * from .format import * from .func import * from .onnxruntime import * from .storage import * from .tqdm_ import *
# imgutils/utils/func.py (new file)
import numpy as np

__all__ = ['sigmoid']


def sigmoid(x):
    """
    Compute the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    The naive formula evaluates ``exp(-x)``, which overflows (and emits a
    ``RuntimeWarning``) for large negative ``x``. This implementation only
    ever evaluates ``exp(-|x|)``, which lies in ``[0, 1]``, and selects the
    algebraically equivalent expression for each sign of ``x``.

    :param x: Real-valued scalar or array-like input.
    :return: Sigmoid of ``x``. A scalar is returned for scalar input and an
        ``ndarray`` of the same shape for array input. Floating dtypes are
        preserved; boolean and integer inputs are computed in ``float64``.

    Examples::
        >>> sigmoid(0.0)
        0.5
        >>> sigmoid([-1000.0, 0.0, 1000.0]).tolist()
        [0.0, 0.5, 1.0]
    """
    x = np.asarray(x)
    if x.dtype.kind in 'biu':
        # Boolean/integer inputs cannot hold the result; promote like np.exp would.
        x = x.astype(np.float64)

    # z = exp(-|x|) is always within [0, 1], so it can never overflow.
    z = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)).
    # Both branches are safe to evaluate everywhere because 1 + z >= 1.
    result = np.where(x >= 0, 1 / (1 + z), z / (1 + z))

    # Indexing with an empty tuple unwraps 0-d results to a scalar and
    # leaves n-d arrays untouched, so scalar callers still get a scalar.
    return result[()]