Commit 2b862b6b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add extractor

parent 080716e0
Loading
Loading
Loading
Loading
+4 −27
Original line number Diff line number Diff line
@@ -19,8 +19,9 @@ from typing import Dict, Optional, Tuple
import numpy as np
from huggingface_hub import hf_hub_download

from imgutils.data import ImageTyping
from imgutils.generic import ClassifyModel
from ..data import ImageTyping
from ..generic import ClassifyModel
from ..utils import vreplace

__all__ = [
    'anime_dbaesthetic',
@@ -40,30 +41,6 @@ _DEFAULT_LABEL_MAPPING = {
}


def _value_replace(v, mapping):
    """
    Replaces values in a data structure using a mapping dictionary.

    :param v: The input data structure.
    :type v: Any
    :param mapping: A dictionary mapping values to replacement values.
    :type mapping: Dict
    :return: The modified data structure.
    :rtype: Any
    """
    if isinstance(v, (list, tuple)):
        return type(v)([_value_replace(vitem, mapping) for vitem in v])
    elif isinstance(v, dict):
        return type(v)({key: _value_replace(value, mapping) for key, value in v.items()})
    else:
        try:
            _ = hash(v)
        except TypeError:  # pragma: no cover
            return v
        else:
            return mapping.get(v, v)


class AestheticModel:
    """
    A model for assessing the aesthetic quality of anime images.
@@ -171,7 +148,7 @@ class AestheticModel:
        score, confidence = self.get_aesthetic_score(image, model_name)
        percentile = self.score_to_percentile(score, model_name)
        label = self.percentile_to_label(percentile)
        return _value_replace(
        return vreplace(
            v=fmt,
            mapping={
                'label': label,
+22 −4
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ from huggingface_hub import hf_hub_download
from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model
from ..utils import open_onnx_model, vreplace

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
@@ -64,7 +64,10 @@ def _get_wd14_model(model_name):
    :rtype: ONNXModel
    """
    _version_support_check(model_name)
    return open_onnx_model(hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME))
    return open_onnx_model(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        filename=f'{MODEL_NAMES[model_name]}/model.onnx',
    ))


@lru_cache()
@@ -133,6 +136,7 @@ def get_wd14_tags(
        character_mcut_enabled: bool = False,
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt=('rating', 'general', 'character'),
):
    """
    Overview:
@@ -155,6 +159,9 @@ def get_wd14_tags(
    :type no_underline: bool
    :param drop_overlap: If True, drops overlapping tags.
    :type drop_overlap: bool
    :param fmt: Return format, default is ``('rating', 'general', 'character')``.
        ``embedding`` is also supported for feature extraction.
    :type fmt: Any
    :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
    :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]

@@ -189,8 +196,10 @@ def get_wd14_tags(
    image = _prepare_image_for_tagging(image, target_size)

    input_name = model.get_inputs()[0].name
    assert len(model.get_outputs()) == 2
    label_name = model.get_outputs()[0].name
    preds = model.run([label_name], {input_name: image})[0]
    emb_name = model.get_outputs()[1].name
    preds, embeddings = model.run([label_name, emb_name], {input_name: image})
    labels = list(zip(tag_names, preds[0].astype(float)))

    ratings_names = [labels[i] for i in rating_indexes]
@@ -215,4 +224,13 @@ def get_wd14_tags(
    character_res = [x for x in character_names if x[1] > character_threshold]
    character_res = dict(character_res)

    return rating, general_res, character_res
    return vreplace(
        fmt,
        {
            'rating': rating,
            'general': general_res,
            'character': character_res,
            'tag': {**general_res, **character_res},
            'embedding': embeddings,
        }
    )
+1 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ Overview:
    Generic utilities for :mod:`imgutils`.
"""
from .area import *
from .format import *
from .onnxruntime import *
from .storage import *
from .tqdm_ import *
+26 −0
Original line number Diff line number Diff line
__all__ = [
    'vreplace',
]


def vreplace(v, mapping):
    """
    Replaces values in a data structure using a mapping dictionary.
    :param v: The input data structure.
    :type v: Any
    :param mapping: A dictionary mapping values to replacement values.
    :type mapping: Dict
    :return: The modified data structure.
    :rtype: Any
    """
    if isinstance(v, (list, tuple)):
        return type(v)([vreplace(vitem, mapping) for vitem in v])
    elif isinstance(v, dict):
        return type(v)({key: vreplace(value, mapping) for key, value in v.items()})
    else:
        try:
            _ = hash(v)
        except TypeError:  # pragma: no cover
            return v
        else:
            return mapping.get(v, v)