Loading imgutils/metrics/dbaesthetic.py +4 −27 Original line number Diff line number Diff line Loading @@ -19,8 +19,9 @@ from typing import Dict, Optional, Tuple import numpy as np from huggingface_hub import hf_hub_download from imgutils.data import ImageTyping from imgutils.generic import ClassifyModel from ..data import ImageTyping from ..generic import ClassifyModel from ..utils import vreplace __all__ = [ 'anime_dbaesthetic', Loading @@ -40,30 +41,6 @@ _DEFAULT_LABEL_MAPPING = { } def _value_replace(v, mapping): """ Replaces values in a data structure using a mapping dictionary. :param v: The input data structure. :type v: Any :param mapping: A dictionary mapping values to replacement values. :type mapping: Dict :return: The modified data structure. :rtype: Any """ if isinstance(v, (list, tuple)): return type(v)([_value_replace(vitem, mapping) for vitem in v]) elif isinstance(v, dict): return type(v)({key: _value_replace(value, mapping) for key, value in v.items()}) else: try: _ = hash(v) except TypeError: # pragma: no cover return v else: return mapping.get(v, v) class AestheticModel: """ A model for assessing the aesthetic quality of anime images. Loading Loading @@ -171,7 +148,7 @@ class AestheticModel: score, confidence = self.get_aesthetic_score(image, model_name) percentile = self.score_to_percentile(score, model_name) label = self.percentile_to_label(percentile) return _value_replace( return vreplace( v=fmt, mapping={ 'label': label, Loading imgutils/tagging/wd14.py +22 −4 Original line number Diff line number Diff line Loading @@ -16,7 +16,7 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model from ..utils import open_onnx_model, vreplace SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -64,7 +64,10 @@ def _get_wd14_model(model_name): :rtype: ONNXModel """ _version_support_check(model_name) return open_onnx_model(hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)) return open_onnx_model(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', filename=f'{MODEL_NAMES[model_name]}/model.onnx', )) @lru_cache() Loading Loading @@ -133,6 +136,7 @@ def get_wd14_tags( character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Overview: Loading @@ -155,6 +159,9 @@ def get_wd14_tags( :type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. :type fmt: Any :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] Loading Loading @@ -189,8 +196,10 @@ def get_wd14_tags( image = _prepare_image_for_tagging(image, target_size) input_name = model.get_inputs()[0].name assert len(model.get_outputs()) == 2 label_name = model.get_outputs()[0].name preds = model.run([label_name], {input_name: image})[0] emb_name = model.get_outputs()[1].name preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) ratings_names = [labels[i] for i in rating_indexes] Loading @@ -215,4 +224,13 @@ def get_wd14_tags( character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) return rating, general_res, character_res return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings, } ) imgutils/utils/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -3,6 +3,7 @@ Overview: Generic utilities for :mod:`imgutils`. """ from .area import * from .format import * from .onnxruntime import * from .storage import * from .tqdm_ import * imgutils/utils/format.py 0 → 100644 +26 −0 Original line number Diff line number Diff line __all__ = [ 'vreplace', ] def vreplace(v, mapping): """ Replaces values in a data structure using a mapping dictionary. :param v: The input data structure. :type v: Any :param mapping: A dictionary mapping values to replacement values. :type mapping: Dict :return: The modified data structure. :rtype: Any """ if isinstance(v, (list, tuple)): return type(v)([vreplace(vitem, mapping) for vitem in v]) elif isinstance(v, dict): return type(v)({key: vreplace(value, mapping) for key, value in v.items()}) else: try: _ = hash(v) except TypeError: # pragma: no cover return v else: return mapping.get(v, v) Loading
imgutils/metrics/dbaesthetic.py +4 −27 Original line number Diff line number Diff line Loading @@ -19,8 +19,9 @@ from typing import Dict, Optional, Tuple import numpy as np from huggingface_hub import hf_hub_download from imgutils.data import ImageTyping from imgutils.generic import ClassifyModel from ..data import ImageTyping from ..generic import ClassifyModel from ..utils import vreplace __all__ = [ 'anime_dbaesthetic', Loading @@ -40,30 +41,6 @@ _DEFAULT_LABEL_MAPPING = { } def _value_replace(v, mapping): """ Replaces values in a data structure using a mapping dictionary. :param v: The input data structure. :type v: Any :param mapping: A dictionary mapping values to replacement values. :type mapping: Dict :return: The modified data structure. :rtype: Any """ if isinstance(v, (list, tuple)): return type(v)([_value_replace(vitem, mapping) for vitem in v]) elif isinstance(v, dict): return type(v)({key: _value_replace(value, mapping) for key, value in v.items()}) else: try: _ = hash(v) except TypeError: # pragma: no cover return v else: return mapping.get(v, v) class AestheticModel: """ A model for assessing the aesthetic quality of anime images. Loading Loading @@ -171,7 +148,7 @@ class AestheticModel: score, confidence = self.get_aesthetic_score(image, model_name) percentile = self.score_to_percentile(score, model_name) label = self.percentile_to_label(percentile) return _value_replace( return vreplace( v=fmt, mapping={ 'label': label, Loading
imgutils/tagging/wd14.py +22 −4 Original line number Diff line number Diff line Loading @@ -16,7 +16,7 @@ from huggingface_hub import hf_hub_download from .format import remove_underline from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping from ..utils import open_onnx_model from ..utils import open_onnx_model, vreplace SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" Loading Loading @@ -64,7 +64,10 @@ def _get_wd14_model(model_name): :rtype: ONNXModel """ _version_support_check(model_name) return open_onnx_model(hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)) return open_onnx_model(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', filename=f'{MODEL_NAMES[model_name]}/model.onnx', )) @lru_cache() Loading Loading @@ -133,6 +136,7 @@ def get_wd14_tags( character_mcut_enabled: bool = False, no_underline: bool = False, drop_overlap: bool = False, fmt=('rating', 'general', 'character'), ): """ Overview: Loading @@ -155,6 +159,9 @@ def get_wd14_tags( :type no_underline: bool :param drop_overlap: If True, drops overlapping tags. :type drop_overlap: bool :param fmt: Return format, default is ``('rating', 'general', 'character')``. ``embedding`` is also supported for feature extraction. :type fmt: Any :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities. :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]] Loading Loading @@ -189,8 +196,10 @@ def get_wd14_tags( image = _prepare_image_for_tagging(image, target_size) input_name = model.get_inputs()[0].name assert len(model.get_outputs()) == 2 label_name = model.get_outputs()[0].name preds = model.run([label_name], {input_name: image})[0] emb_name = model.get_outputs()[1].name preds, embeddings = model.run([label_name, emb_name], {input_name: image}) labels = list(zip(tag_names, preds[0].astype(float))) ratings_names = [labels[i] for i in rating_indexes] Loading @@ -215,4 +224,13 @@ def get_wd14_tags( character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) return rating, general_res, character_res return vreplace( fmt, { 'rating': rating, 'general': general_res, 'character': character_res, 'tag': {**general_res, **character_res}, 'embedding': embeddings, } )
imgutils/utils/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -3,6 +3,7 @@ Overview: Generic utilities for :mod:`imgutils`. """ from .area import * from .format import * from .onnxruntime import * from .storage import * from .tqdm_ import *
imgutils/utils/format.py 0 → 100644 +26 −0 Original line number Diff line number Diff line __all__ = [ 'vreplace', ] def vreplace(v, mapping): """ Replaces values in a data structure using a mapping dictionary. :param v: The input data structure. :type v: Any :param mapping: A dictionary mapping values to replacement values. :type mapping: Dict :return: The modified data structure. :rtype: Any """ if isinstance(v, (list, tuple)): return type(v)([vreplace(vitem, mapping) for vitem in v]) elif isinstance(v, dict): return type(v)({key: vreplace(value, mapping) for key, value in v.items()}) else: try: _ = hash(v) except TypeError: # pragma: no cover return v else: return mapping.get(v, v)