Commit e9bf20b6 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest for new convert function

parent eb173500
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -16,4 +16,4 @@ from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags
from .wd14 import get_wd14_tags
from .wd14 import get_wd14_tags, convert_wd14_emb_to_prediction
+15 −1
Original line number Diff line number Diff line
import pytest

from imgutils.tagging import get_wd14_tags
from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
from imgutils.tagging.wd14 import _get_wd14_model
from test.testings import get_testfile

@@ -159,3 +159,17 @@ class TestTaggingWd14:
            'tube_top': 0.9783295392990112, 'bead_bracelet': 0.3510066270828247, 'red_bandeau': 0.8741766214370728
        }, abs=2e-2)
        assert chars == pytest.approx({'nian_(arknights)': 0.9968841671943665}, abs=2e-2)

    @pytest.mark.parametrize(['file'], [
        ('nude_girl.png',),
        ('nian.png',),
    ])
    def test_convert_wd14_emb_to_prediction(self, file):
        file = get_testfile(file)
        (expected_rating, expected_general, expected_character), embedding = \
            get_wd14_tags(file, fmt=(('rating', 'general', 'character'), 'embedding'))

        rating, general, character = convert_wd14_emb_to_prediction(embedding)
        assert rating == pytest.approx(expected_rating, abs=2e-3)
        assert general == pytest.approx(expected_general, abs=2e-3)
        assert character == pytest.approx(expected_character, abs=2e-3)