Commit 15c0b432 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add denormalize and unittests

parent 510b6cc3
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -16,4 +16,4 @@ from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb
+74 −10
Original line number Diff line number Diff line
@@ -214,6 +214,9 @@ def _postprocess_embedding(
    :param fmt: The format of the output.
    :return: The post-processed results.
    """
    assert len(pred.shape) == len(embedding.shape) == 1, \
        f'Both pred and embeddings shapes should be 1-dim, ' \
        f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.'
    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
    labels = list(zip(tag_names, pred.astype(float)))

@@ -356,6 +359,9 @@ def get_wd14_tags(
    )


_DEFAULT_DENORMALIZER_NAME = 'mnum2_all'


def convert_wd14_emb_to_prediction(
        emb: np.ndarray,
        model_name: str = _DEFAULT_MODEL_NAME,
@@ -366,6 +372,8 @@ def convert_wd14_emb_to_prediction(
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt=('rating', 'general', 'character'),
        denormalize: bool = False,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
    """
    Convert WD14 embedding to understandable prediction result.
@@ -403,9 +411,17 @@ def convert_wd14_emb_to_prediction(
        >>> rating, general, character = convert_wd14_emb_to_prediction(embedding)
        >>> # these 3 dicts will be the same as that returned by `get_wd14_tags('skadi.jpg')`
    """
    if denormalize:
        emb = denormalize_wd14_emb(
            emb=emb,
            model_name=model_name,
            denormalizer_name=denormalizer_name,
        )

    z_weights = _get_wd14_weights(model_name)
    weights, bias = z_weights['weights'], z_weights['bias']
    pred = sigmoid(emb @ weights + bias)
    if len(emb.shape) == 1:
        return _postprocess_embedding(
            pred=pred,
            embedding=emb,
@@ -418,3 +434,51 @@ def convert_wd14_emb_to_prediction(
            drop_overlap=drop_overlap,
            fmt=fmt,
        )
    else:
        return [
            _postprocess_embedding(
                pred=pred_item,
                embedding=emb_item,
                model_name=model_name,
                general_threshold=general_threshold,
                general_mcut_enabled=general_mcut_enabled,
                character_threshold=character_threshold,
                character_mcut_enabled=character_mcut_enabled,
                no_underline=no_underline,
                drop_overlap=drop_overlap,
                fmt=fmt,
            )
            for pred_item, emb_item in zip(pred, emb)
        ]


@ts_lru_cache()
def _open_denormalize_model(
        model_name: str = _DEFAULT_MODEL_NAME,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
    return open_onnx_model(hf_hub_download(
        repo_id='deepghs/embedding_aligner',
        repo_type='model',
        filename=f'{model_name}_{denormalizer_name}/model.onnx',
    ))


def denormalize_wd14_emb(
        emb: np.ndarray,
        model_name: str = _DEFAULT_MODEL_NAME,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
) -> np.ndarray:
    model = _open_denormalize_model(
        model_name=model_name,
        denormalizer_name=denormalizer_name,
    )
    if len(emb.shape) == 1:
        output, = model.run(output_names=['embedding'], input_feed={'input': emb[None, ...]})
        return output[0]
    else:
        embedding_width = model.get_outputs()[0].shape[-1]
        origin_shape = emb.shape
        emb = emb.reshape(-1, embedding_width)
        output, = model.run(output_names=['embedding'], input_feed={'input': emb})
        return output.reshape(*origin_shape)
+57 −1
Original line number Diff line number Diff line
import numpy as np
import pytest

from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
from imgutils.tagging.wd14 import _get_wd14_model
from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model
from test.testings import get_testfile


@@ -11,6 +12,7 @@ def _release_model_after_run():
        yield
    finally:
        _get_wd14_model.cache_clear()
        _open_denormalize_model.cache_clear()


@pytest.mark.unittest
@@ -173,3 +175,57 @@ class TestTaggingWd14:
        assert rating == pytest.approx(expected_rating, abs=2e-3)
        assert general == pytest.approx(expected_general, abs=2e-3)
        assert character == pytest.approx(expected_character, abs=2e-3)

    @pytest.mark.parametrize(['file'], [
        ('nude_girl.png',),
    ])
    def test_convert_wd14_emb_to_prediction_denormalize(self, file):
        file = get_testfile(file)
        (expected_rating, expected_general, expected_character), embedding = \
            get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))

        embedding = embedding / np.linalg.norm(embedding)
        rating, general, character = convert_wd14_emb_to_prediction(embedding, denormalize=True)
        assert rating == pytest.approx(expected_rating, abs=1e-2)
        assert general == pytest.approx(expected_general, abs=1e-2)
        assert character == pytest.approx(expected_character, abs=1e-2)

    @pytest.mark.parametrize(['file'], [
        ('nude_girl.png',),
        # ('nian.png',),  # some low scores not match
    ])
    def test_denormalize_wd14_emb(self, file):
        file = get_testfile(file)
        (expected_rating, expected_general, expected_character), embedding = \
            get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))

        embedding = embedding / np.linalg.norm(embedding)
        output = denormalize_wd14_emb(embedding)
        rating, general, character = convert_wd14_emb_to_prediction(output)
        assert rating == pytest.approx(expected_rating, abs=1e-2)
        assert general == pytest.approx(expected_general, abs=1e-2)
        assert character == pytest.approx(expected_character, abs=1e-2)

    @pytest.mark.parametrize(['files'], [
        (['nude_girl.png'],),
        (['nude_girl.png', 'nude_girl.png'],),
        # ('nian.png',),  # some low scores not match
    ])
    def test_denormalize_wd14_emb_multiple(self, files):
        files = [get_testfile(file) for file in files]
        expected = []
        embeddings = []
        for file in files:
            (expected_rating, expected_general, expected_character), embedding = \
                get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))
            expected.append((expected_rating, expected_general, expected_character))
            embeddings.append(embedding / np.linalg.norm(embedding))

        embeddings = np.stack(embeddings)
        outputs = denormalize_wd14_emb(embeddings)
        actual = convert_wd14_emb_to_prediction(outputs)
        for (expected_rating, expected_general, expected_character), \
                (rating, general, character) in zip(expected, actual):
            assert rating == pytest.approx(expected_rating, abs=1e-2)
            assert general == pytest.approx(expected_general, abs=1e-2)
            assert character == pytest.approx(expected_character, abs=1e-2)