Unverified Commit d43dcae5 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #126 from deepghs/dev/denormalize

dev(narugo): add de-normalizers for the embeddings of the wd14 taggers
parents 510b6cc3 256758af
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -20,3 +20,9 @@ convert_wd14_emb_to_prediction



denormalize_wd14_emb
----------------------------------------------

.. autofunction:: denormalize_wd14_emb

+1 −1
Original line number Diff line number Diff line
@@ -16,4 +16,4 @@ from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb
+147 −22
Original line number Diff line number Diff line
@@ -214,6 +214,9 @@ def _postprocess_embedding(
    :param fmt: The format of the output.
    :return: The post-processed results.
    """
    assert len(pred.shape) == len(embedding.shape) == 1, \
        f'Both pred and embeddings shapes should be 1-dim, ' \
        f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.'
    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
    labels = list(zip(tag_names, pred.astype(float)))

@@ -356,6 +359,9 @@ def get_wd14_tags(
    )


# Default denormalizer variant. Together with the model name it forms the
# ``{model_name}_{denormalizer_name}/model.onnx`` filename that is downloaded
# from the ``deepghs/embedding_aligner`` repository.
_DEFAULT_DENORMALIZER_NAME = 'mnum2_all'


def convert_wd14_emb_to_prediction(
        emb: np.ndarray,
        model_name: str = _DEFAULT_MODEL_NAME,
@@ -366,46 +372,76 @@ def convert_wd14_emb_to_prediction(
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt=('rating', 'general', 'character'),
        denormalize: bool = False,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
    """
    Convert WD14 embedding to understandable prediction result.
    Convert WD14 embedding to understandable prediction result. This function can process both
    single embeddings (1-dimensional array) and batches of embeddings (2-dimensional array).

    :param emb: The 1-dim extracted embedding.
    :param emb: The extracted embedding(s). Can be either a 1-dim array for single image or
                2-dim array for batch processing
    :type emb: numpy.ndarray
    :param model_name: The name of the model to use.
    :param model_name: Name of the WD14 model to use for prediction
    :type model_name: str
    :param general_threshold: The threshold for general tags.
    :param general_threshold: Confidence threshold for general tags (0.0 to 1.0)
    :type general_threshold: float
    :param general_mcut_enabled: If True, applies MCut thresholding to general tags.
    :param general_mcut_enabled: Enable MCut thresholding for general tags to improve prediction quality
    :type general_mcut_enabled: bool
    :param character_threshold: The threshold for character tags.
    :param character_threshold: Confidence threshold for character tags (0.0 to 1.0)
    :type character_threshold: float
    :param character_mcut_enabled: If True, applies MCut thresholding to character tags.
    :param character_mcut_enabled: Enable MCut thresholding for character tags to improve prediction quality
    :type character_mcut_enabled: bool
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :param no_underline: Replace underscores with spaces in tag names for better readability
    :type no_underline: bool
    :param drop_overlap: If True, drops overlapping tags.
    :param drop_overlap: Remove overlapping tags to reduce redundancy
    :type drop_overlap: bool
    :param fmt: Return format, default is ``('rating', 'general', 'character')``.
    :return: Prediction result based on the provided fmt.
    :param fmt: Specify return format structure for predictions, default is ``('rating', 'general', 'character')``.
    :type fmt: tuple
    :param denormalize: Whether to denormalize the embedding before prediction
    :type denormalize: bool
    :param denormalizer_name: Name of the denormalizer to use if denormalization is enabled
    :type denormalizer_name: str
    :return: For single embeddings: prediction result based on fmt. For batches: list of prediction results.

    .. note::
        Only embeddings that have not been normalized can be converted to an understandable
        prediction result. If normalized embeddings are provided, set ``denormalize=True``
        to denormalize them first.

    For batch processing (2-dim input), returns a list where each element corresponds
    to one embedding's predictions in the same format as single embedding output.

    Example:
        >>> import os
        >>> import numpy as np
        >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
        >>>
        >>> # extract the feature embedding
        >>> # extract the feature embedding, shape: (W, )
        >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding')
        >>>
        >>> # convert to understandable result
        >>> rating, general, character = convert_wd14_emb_to_prediction(embedding)
        >>> # these 3 dicts will be the same as that returned by `get_wd14_tags('skadi.jpg')`
        >>>
        >>> # Batch processing, shape: (B, W)
        >>> embeddings = np.stack([
        ...     get_wd14_tags('img1.jpg', fmt='embedding'),
        ...     get_wd14_tags('img2.jpg', fmt='embedding'),
        ... ])
        >>> # results will be a list of (rating, general, character) tuples
        >>> results = convert_wd14_emb_to_prediction(embeddings)
    """
    if denormalize:
        emb = denormalize_wd14_emb(
            emb=emb,
            model_name=model_name,
            denormalizer_name=denormalizer_name,
        )

    z_weights = _get_wd14_weights(model_name)
    weights, bias = z_weights['weights'], z_weights['bias']
    pred = sigmoid(emb @ weights + bias)
    if len(emb.shape) == 1:
        return _postprocess_embedding(
            pred=pred,
            embedding=emb,
@@ -418,3 +454,92 @@ def convert_wd14_emb_to_prediction(
            drop_overlap=drop_overlap,
            fmt=fmt,
        )
    else:
        return [
            _postprocess_embedding(
                pred=pred_item,
                embedding=emb_item,
                model_name=model_name,
                general_threshold=general_threshold,
                general_mcut_enabled=general_mcut_enabled,
                character_threshold=character_threshold,
                character_mcut_enabled=character_mcut_enabled,
                no_underline=no_underline,
                drop_overlap=drop_overlap,
                fmt=fmt,
            )
            for pred_item, emb_item in zip(pred, emb)
        ]


@ts_lru_cache()
def _open_denormalize_model(
        model_name: str = _DEFAULT_MODEL_NAME,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
    """
    Load the ONNX denormalizer that belongs to a WD14 model.

    The model file ``{model_name}_{denormalizer_name}/model.onnx`` is downloaded from the
    ``deepghs/embedding_aligner`` model repository and then opened with ``open_onnx_model``.
    The function is wrapped in ``ts_lru_cache``, so repeated calls with the same arguments
    presumably reuse the already-opened model (the cache can be reset via ``cache_clear()``).

    :param model_name: Name of the WD14 model the denormalizer was built for.
    :type model_name: str
    :param denormalizer_name: Name of the denormalizer variant.
    :type denormalizer_name: str
    :return: The opened ONNX model, as returned by ``open_onnx_model``.
    """
    onnx_file = hf_hub_download(
        repo_id='deepghs/embedding_aligner',
        repo_type='model',
        filename=f'{model_name}_{denormalizer_name}/model.onnx',
    )
    return open_onnx_model(onnx_file)


def denormalize_wd14_emb(
        emb: np.ndarray,
        model_name: str = _DEFAULT_MODEL_NAME,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
) -> np.ndarray:
    """
    Denormalize WD14 embeddings.

    The input is first rescaled to unit L2 norm along its last axis, so both raw and
    already-normalized embeddings are accepted. It is then flattened into a 2-dim batch,
    passed through the denormalizer ONNX model, and the model output is reshaped back to
    the shape of the input.

    :param emb: The embedding(s) to denormalize. The last axis is the embedding axis; either
        a single 1-dim embedding or an array with any number of leading batch axes.
    :type emb: numpy.ndarray
    :param model_name: Name of the model.
    :type model_name: str
    :param denormalizer_name: Name of the denormalizer.
    :type denormalizer_name: str
    :return: The denormalized embedding(s), with the same shape as ``emb``.
    :rtype: numpy.ndarray

    .. note::
        An all-zero embedding has zero norm and cannot be rescaled; the division then
        produces non-finite values.

    Examples:
        >>> import numpy as np
        >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb
        >>>
        >>> embedding, (r, g, c) = get_wd14_tags(
        ...     'image.png',
        ...     fmt=('embedding', ('rating', 'general', 'character')),
        ... )
        >>>
        >>> # normalize embedding
        >>> embedding = embedding / np.linalg.norm(embedding)
        >>>
        >>> # denormalize this embedding
        >>> output = denormalize_wd14_emb(embedding)
        >>>
        >>> # should be similar to r, g, c, approx 1e-3 error
        >>> rating, general, character = convert_wd14_emb_to_prediction(output)
    """
    model = _open_denormalize_model(
        model_name=model_name,
        denormalizer_name=denormalizer_name,
    )
    emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)

    # Flatten every leading axis into one batch axis, keeping the embedding axis intact.
    # The width is taken from the input itself (not from the model's output metadata) so
    # that rows are never re-split when the input width does not match the model; a
    # mismatching width is then left for the runtime to reject.
    origin_shape = emb.shape
    batch = emb.reshape(-1, origin_shape[-1])
    output, = model.run(output_names=['embedding'], input_feed={'input': batch})
    return output.reshape(origin_shape)
+57 −1
Original line number Diff line number Diff line
import numpy as np
import pytest

from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
from imgutils.tagging.wd14 import _get_wd14_model
from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model
from test.testings import get_testfile


@@ -11,6 +12,7 @@ def _release_model_after_run():
        yield
    finally:
        _get_wd14_model.cache_clear()
        _open_denormalize_model.cache_clear()


@pytest.mark.unittest
@@ -173,3 +175,57 @@ class TestTaggingWd14:
        assert rating == pytest.approx(expected_rating, abs=2e-3)
        assert general == pytest.approx(expected_general, abs=2e-3)
        assert character == pytest.approx(expected_character, abs=2e-3)

    @pytest.mark.parametrize(['file'], [
        ('nude_girl.png',),
    ])
    def test_convert_wd14_emb_to_prediction_denormalize(self, file):
        """
        A unit-length embedding converted with ``denormalize=True`` must yield
        (within 1e-2) the same tags as predicting directly from the image.
        """
        image_path = get_testfile(file)
        expected, raw_emb = get_wd14_tags(
            image_path,
            fmt=(('rating', 'general', 'character'), 'embedding'),
        )
        expected_rating, expected_general, expected_character = expected

        # Throw away the scale information before asking for it to be restored.
        unit_emb = raw_emb / np.linalg.norm(raw_emb)
        rating, general, character = convert_wd14_emb_to_prediction(unit_emb, denormalize=True)

        assert rating == pytest.approx(expected_rating, abs=1e-2)
        assert general == pytest.approx(expected_general, abs=1e-2)
        assert character == pytest.approx(expected_character, abs=1e-2)

    @pytest.mark.parametrize(['file'], [
        ('nude_girl.png',),
        # ('nian.png',),  # some low scores not match
    ])
    def test_denormalize_wd14_emb(self, file):
        """
        Explicitly denormalizing a unit-length embedding and then converting it must
        yield (within 1e-2) the same tags as predicting directly from the image.
        """
        image_path = get_testfile(file)
        expected, raw_emb = get_wd14_tags(
            image_path,
            fmt=(('rating', 'general', 'character'), 'embedding'),
        )
        expected_rating, expected_general, expected_character = expected

        unit_emb = raw_emb / np.linalg.norm(raw_emb)
        restored_emb = denormalize_wd14_emb(unit_emb)
        rating, general, character = convert_wd14_emb_to_prediction(restored_emb)

        assert rating == pytest.approx(expected_rating, abs=1e-2)
        assert general == pytest.approx(expected_general, abs=1e-2)
        assert character == pytest.approx(expected_character, abs=1e-2)

    @pytest.mark.parametrize(['files'], [
        (['nude_girl.png'],),
        (['nude_girl.png', 'nude_girl.png'],),
        # ('nian.png',),  # some low scores not match
    ])
    def test_denormalize_wd14_emb_multiple(self, files):
        """
        Denormalizing a stacked batch of unit-length embeddings and converting it back
        must reproduce, per image and within 1e-2, the tags predicted from that image.
        """
        files = [get_testfile(file) for file in files]
        expected = []
        embeddings = []
        for file in files:
            (expected_rating, expected_general, expected_character), embedding = \
                get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))
            expected.append((expected_rating, expected_general, expected_character))
            embeddings.append(embedding / np.linalg.norm(embedding))

        embeddings = np.stack(embeddings)
        outputs = denormalize_wd14_emb(embeddings)
        # The denormalizer must hand back one embedding per input row.
        assert outputs.shape == embeddings.shape

        actual = convert_wd14_emb_to_prediction(outputs)
        # zip() stops at the shorter sequence, so check the count explicitly; otherwise a
        # batch conversion that returned too few results would pass unnoticed.
        assert len(actual) == len(expected)
        for (expected_rating, expected_general, expected_character), \
                (rating, general, character) in zip(expected, actual):
            assert rating == pytest.approx(expected_rating, abs=1e-2)
            assert general == pytest.approx(expected_general, abs=1e-2)
            assert character == pytest.approx(expected_character, abs=1e-2)