Unverified Commit 601dc13a authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #73 from deepghs/dev/underline

dev(narugo): add underline processing
parents eda6d0ba c828fc44
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -6,6 +6,18 @@ imgutils.tagging.format
.. automodule:: imgutils.tagging.format


add_underline
---------------------------

.. autofunction:: add_underline


remove_underline
---------------------------

.. autofunction:: remove_underline


tags_to_text
---------------------------

+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ Overview:
from .blacklist import is_blacklisted, drop_blacklisted_tags
from .character import is_basic_character_tag, drop_basic_character_tags
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .format import tags_to_text, add_underline, remove_underline
from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
+47 −1
Original line number Diff line number Diff line
@@ -6,6 +6,52 @@ import re
from typing import Mapping

RE_SPECIAL = re.compile(r'([\\()])')
_KAOMOJIS = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


def add_underline(tag):
    """
    Adds underscores to a tag string to make it compatible with image labeling conventions.

    :param tag: The input tag string.
    :type tag: str
    :return: The tag string with underscores added.
    :rtype: str
    """
    return tag.strip().replace(' ', '_')


def remove_underline(tag):
    """
    Removes underscores from a tag string, restoring it to its original form.

    :param tag: The input tag string.
    :type tag: str
    :return: The tag string with underscores removed.
    :rtype: str
    """
    tag = tag.strip()
    return tag.replace('_', ' ') if tag not in _KAOMOJIS else tag


def tags_to_text(tags: Mapping[str, float],
@@ -47,7 +93,7 @@ def tags_to_text(tags: Mapping[str, float],
    for tag, score in tags_pairs:
        t_text = tag
        if use_spaces:
            t_text = t_text.replace('_', ' ')
            t_text = remove_underline(t_text)
        if use_escape:
            t_text = re.sub(RE_SPECIAL, r'\\\1', t_text)
        if include_score:
+2 −26
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ import pandas as pd
from PIL import Image
from hbutils.testing.requires.version import VersionInfo

from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model
@@ -49,29 +50,6 @@ def _version_support_check(model_name):
                               f'please upgrade it to 1.17+ version.')  # pragma: no cover


_KAOMOJIS = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


@lru_cache()
def _get_wd14_model(model_name):
    """
@@ -102,9 +80,7 @@ def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str],
    df = pd.read_csv(path)
    name_series = df["name"]
    if no_underline:
        name_series = name_series.map(
            lambda x: x.replace("_", " ") if x not in _KAOMOJIS else x
        )
        name_series = name_series.map(remove_underline)
    tag_names = name_series.tolist()

    rating_indexes = list(np.where(df["category"] == 9)[0])
+13 −1
Original line number Diff line number Diff line
import pytest

from imgutils.tagging import tags_to_text
from imgutils.tagging import tags_to_text, add_underline, remove_underline


@pytest.fixture()
@@ -40,3 +40,15 @@ class TestTaggingFormat:
               '1girl, panties, drinking glass, panty pull, areola slip'
        assert tags_to_text(tag_mapping, use_spaces=True, include_score=True) == \
               '(1girl:0.999), (panties:0.959), (drinking glass:0.934), (panty pull:0.683), (areola slip:0.412)'

    def test_remove_underline(self):
        assert remove_underline('1girl') == '1girl'
        assert remove_underline(' red hair ') == 'red hair'
        assert remove_underline('red_hair ') == 'red hair'
        assert remove_underline('  ||_|| ') == '||_||'

    def test_add_underline(self):
        assert add_underline('1girl') == '1girl'
        assert add_underline('red hair') == 'red_hair'
        assert add_underline(' red hair  ') == 'red_hair'
        assert add_underline(' ||_||  ') == '||_||'