Commit b9e5bfc2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest for wd tagger attachments

parent c1cb9439
Loading
Loading
Loading
Loading
+52 −13
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ Overview:
    project on Hugging Face.

"""
from typing import List, Tuple, Any
from collections import defaultdict
from typing import List, Tuple, Any, Optional, Mapping, Dict, Union

import numpy as np
import onnxruntime
@@ -19,7 +20,8 @@ from huggingface_hub import hf_hub_download
from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache
from ..generic.attachment import open_attachment, Attachment
from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache, vnames

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
@@ -189,6 +191,7 @@ def _postprocess_embedding(
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt: Any = ('rating', 'general', 'character'),
        attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None,
):
    """
    Post-process the embedding and prediction results.
@@ -214,6 +217,27 @@ def _postprocess_embedding(
    :param fmt: The format of the output.
    :return: The post-processed results.
    """
    attachments = dict(attachments or {})
    d_attachments: Dict[str, Tuple[Attachment, dict]] = {}
    for attach_name, attach_tpl in attachments.items():
        if '/' in attach_name:
            raise ValueError(f'Invalid attachment register name, no \'/\' required - {attach_name!r}.')

        if len(attach_tpl) == 2:
            (attach_repo_id, attach_model_name), attach_kwargs = attach_tpl, {}
        elif len(attach_tpl) == 3:
            attach_repo_id, attach_model_name, attach_kwargs = attach_tpl
        else:
            raise ValueError(f'Invalid attachment tuple for {attach_name!r}, '
                             f'2 or 3 elements expected but {attach_tpl!r} found.')
        attachment = open_attachment(repo_id=attach_repo_id, model_name=attach_model_name)
        expected_encoder_model = f'wdtagger:{MODEL_NAMES[model_name]}'
        if attachment.encoder_model != expected_encoder_model:
            raise ValueError(f'Attachment encoder model not match, '
                             f'{expected_encoder_model!r} expected but {attachment.encoder_model!r} found '
                             f'for {attach_name!r}.')
        d_attachments[attach_name] = (attachment, attach_kwargs)

    assert len(pred.shape) == len(embedding.shape) == 1, \
        f'Both pred and embeddings shapes should be 1-dim, ' \
        f'but pred: {pred.shape!r}, embedding: {embedding.shape!r} actually found.'
@@ -239,9 +263,7 @@ def _postprocess_embedding(

    character_res = {x: v.item() for x, v in character_names if v > character_threshold}

    return vreplace(
        fmt,
        {
    mapping_values = {
        'rating': rating,
        'general': general_res,
        'character': character_res,
@@ -249,7 +271,19 @@ def _postprocess_embedding(
        'embedding': embedding.astype(np.float32),
        'prediction': pred.astype(np.float32),
    }
    )

    d_attach_infers = defaultdict(list)
    for vname in vnames(fmt):
        if '/' in vname and vname.split('/', maxsplit=1)[0] in d_attachments:
            attach_name, attach_fmt_name = vname.split('/', maxsplit=1)
            d_attach_infers[attach_name].append(attach_fmt_name)
    for attach_name, attach_infer_names in d_attach_infers.items():
        attachment, attach_kwargs = d_attachments[attach_name]
        attach_infer_values = attachment.predict(embedding=embedding, fmt=attach_infer_names, **attach_kwargs)
        attach_mapping_names = [f'{attach_name}/{name}' for name in attach_infer_names]
        mapping_values.update(dict(zip(attach_mapping_names, attach_infer_values)))

    return vreplace(fmt, mapping_values)


def get_wd14_tags(
@@ -262,6 +296,7 @@ def get_wd14_tags(
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt: Any = ('rating', 'general', 'character'),
        attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None,
):
    """
    Get tags for an image using WD14 taggers.
@@ -356,6 +391,7 @@ def get_wd14_tags(
        no_underline=no_underline,
        drop_overlap=drop_overlap,
        fmt=fmt,
        attachments=attachments,
    )


@@ -372,6 +408,7 @@ def convert_wd14_emb_to_prediction(
        no_underline: bool = False,
        drop_overlap: bool = False,
        fmt: Any = ('rating', 'general', 'character'),
        attachments: Optional[Mapping[str, Union[Tuple[str, str], Tuple[str, str, dict]]]] = None,
        denormalize: bool = False,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
@@ -453,6 +490,7 @@ def convert_wd14_emb_to_prediction(
            no_underline=no_underline,
            drop_overlap=drop_overlap,
            fmt=fmt,
            attachments=attachments,
        )
    else:
        return [
@@ -467,6 +505,7 @@ def convert_wd14_emb_to_prediction(
                no_underline=no_underline,
                drop_overlap=drop_overlap,
                fmt=fmt,
                attachments=attachments,
            )
            for pred_item, emb_item in zip(pred, emb)
        ]
+91 −0
Original line number Diff line number Diff line
import numpy as np
import pytest
from hbutils.testing import tmatrix

from imgutils.generic.attachment import open_attachment
from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
from imgutils.tagging.wd14 import _get_wd14_model, denormalize_wd14_emb, _open_denormalize_model
from test.testings import get_testfile
@@ -13,6 +15,7 @@ def _release_model_after_run():
    finally:
        _get_wd14_model.cache_clear()
        _open_denormalize_model.cache_clear()
        open_attachment.cache_clear()


@pytest.mark.unittest
@@ -229,3 +232,91 @@ class TestTaggingWd14:
            assert rating == pytest.approx(expected_rating, abs=1e-2)
            assert general == pytest.approx(expected_general, abs=1e-2)
            assert character == pytest.approx(expected_character, abs=1e-2)

    @pytest.mark.parametrize(*tmatrix({
        ('type_', 'file'): [
            ('monochrome', '6130053.jpg'),
            ('monochrome', '6125854(第 3 个复件).jpg'),
            ('monochrome', '5221834.jpg'),
            ('monochrome', '1951253.jpg'),
            ('monochrome', '4879658.jpg'),
            ('monochrome', '80750471_p3_master1200.jpg'),

            ('normal', '54566940_p0_master1200.jpg'),
            ('normal', '60817155_p18_master1200.jpg'),
            ('normal', '4945494.jpg'),
            ('normal', '4008375.jpg'),
            ('normal', '2416278.jpg'),
            ('normal', '842709.jpg')
        ],
    }, mode='matrix'))
    def test_get_wd14_tags_with_attachments(self, type_, file):
        filename = get_testfile('dataset', 'monochrome_danbooru', type_, file)
        scores, (top_label, top_score) = get_wd14_tags(
            filename,
            fmt=('monochrome/scores', 'monochrome/top'),
            attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')},
        )
        assert scores[type_] >= 0.5
        assert top_label == type_
        assert top_score >= 0.5

    @pytest.mark.parametrize(*tmatrix({
        ('type_', 'file'): [
            ('monochrome', '6130053.jpg'),
            ('monochrome', '6125854(第 3 个复件).jpg'),
            ('monochrome', '5221834.jpg'),
            ('monochrome', '1951253.jpg'),
            ('monochrome', '4879658.jpg'),
            ('monochrome', '80750471_p3_master1200.jpg'),

            ('normal', '54566940_p0_master1200.jpg'),
            ('normal', '60817155_p18_master1200.jpg'),
            ('normal', '4945494.jpg'),
            ('normal', '4008375.jpg'),
            ('normal', '2416278.jpg'),
            ('normal', '842709.jpg')
        ],
    }, mode='matrix'))
    def test_get_wd14_tags_with_attachments_extra_cfg(self, type_, file):
        filename = get_testfile('dataset', 'monochrome_danbooru', type_, file)
        scores, (top_label, top_score) = get_wd14_tags(
            filename,
            fmt=('monochrome/scores', 'monochrome/top'),
            attachments={'monochrome': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1', {})},
        )
        assert scores[type_] >= 0.5
        assert top_label == type_
        assert top_score >= 0.5

    @pytest.mark.parametrize(*tmatrix({
        ('type_', 'file'): [
            ('monochrome', '6130053.jpg'),
            ('normal', '54566940_p0_master1200.jpg'),
        ],
    }, mode='matrix'))
    def test_get_wd14_tags_with_attachments_invalid_attachment_name(self, type_, file):
        filename = get_testfile('dataset', 'monochrome_danbooru', type_, file)
        with pytest.raises(ValueError):
            scores, (top_label, top_score) = get_wd14_tags(
                filename,
                fmt=('monochrome/t/scores', 'monochrome/t/top'),
                attachments={'monochrome/t': ('deepghs/eattach_monochrome_experiments', 'mlp_layer1_seed1')},
            )
            _ = top_label, top_score

    @pytest.mark.parametrize(*tmatrix({
        ('type_', 'file'): [
            ('monochrome', '6130053.jpg'),
            ('normal', '54566940_p0_master1200.jpg'),
        ],
    }, mode='matrix'))
    def test_get_wd14_tags_with_attachments_invalid_attachment_config(self, type_, file):
        filename = get_testfile('dataset', 'monochrome_danbooru', type_, file)
        with pytest.raises(ValueError):
            scores, (top_label, top_score) = get_wd14_tags(
                filename,
                fmt=('monochrome/scores', 'monochrome/top'),
                attachments={'monochrome': ('deepghs/eattach_monochrome_experiments',)},
            )
            _ = top_label, top_score