Commit 617289a3 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): optimize tag check system

parent a88a4892
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -17,4 +17,5 @@ imgutils.tagging
    blacklist
    character
    order
    match
+29 −0
Original line number Diff line number Diff line
imgutils.tagging.match
====================================

.. currentmodule:: imgutils.tagging.match

.. automodule:: imgutils.tagging.match


tag_match_suffix
-------------------------------------------------

.. autofunction:: tag_match_suffix



tag_match_prefix
-------------------------------------------------

.. autofunction:: tag_match_prefix



tag_match_full
-------------------------------------------------

.. autofunction:: tag_match_full


+1 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ from .blacklist import is_blacklisted, drop_blacklisted_tags
from .character import is_basic_character_tag, drop_basic_character_tags
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags, drop_overlaps_for_dict
+21 −30
Original line number Diff line number Diff line
@@ -3,10 +3,12 @@ Overview:
    Detect and drop some blacklisted tags, which are listed `here <https://huggingface.co/datasets/alea31415/tag_filtering/blob/main/blacklist_tags.txt>`_.
"""
from functools import lru_cache
from typing import Union, List, Mapping, Set, Optional
from typing import Union, List, Mapping, Set, Optional, Tuple

from huggingface_hub import hf_hub_download

from .match import _words_to_matcher, _split_to_words


@lru_cache()
def _load_online_blacklist() -> List[str]:
@@ -24,39 +26,31 @@ def _load_online_blacklist() -> List[str]:
        return [line.strip() for line in f if line.strip()]


def _is_blacklisted(tag: str, blacklist: Set[str]):
    """
    Check if a tag is blacklisted.

    :param tag: Tag to be checked.
    :type tag: str
    :param blacklist: Set of blacklisted tags.
    :type blacklist: Set[str]
    :return: True if the tag is blacklisted, False otherwise.
    :rtype: bool
    """
    return (tag in blacklist or
            tag.replace('_', ' ') in blacklist or
            tag.replace(' ', '_') in blacklist)


@lru_cache()
def _online_blacklist_set() -> Set[str]:
def _online_blacklist_set() -> Set[Tuple[str, ...]]:
    """
    Get the online blacklist as a set.

    :return: Set of blacklisted tags.
    :rtype: Set[str]
    """
    return set(_load_online_blacklist())
    set_ = set()
    for tag in _load_online_blacklist():
        set_ = set_ | _words_to_matcher(_split_to_words(tag))
    return set_


def _is_blacklisted(tag: str, blacklist_set: Set[Tuple[str, ...]]) -> bool:
    _tag_matcher = _words_to_matcher(_split_to_words(tag))
    return bool(set(_tag_matcher & blacklist_set))

def is_blacklisted(tags: str):

def is_blacklisted(tag: str) -> bool:
    """
    Check if any of the given tags are blacklisted.

    :param tags: Tags to be checked.
    :type tags: str
    :param tag: Tags to be checked.
    :type tag: str
    :return: True if any tag is blacklisted, False otherwise.
    :rtype: bool

@@ -72,7 +66,7 @@ def is_blacklisted(tags: str):
        >>> is_blacklisted('red_hair')
        False
    """
    return _is_blacklisted(tags, _online_blacklist_set())
    return _is_blacklisted(tag, _online_blacklist_set())


def drop_blacklisted_tags(tags: Union[List[str], Mapping[str, float]],
@@ -102,14 +96,11 @@ def drop_blacklisted_tags(tags: Union[List[str], Mapping[str, float]],
        >>> drop_blacklisted_tags(['solo', '1girl', 'cosplay', 'no_eyewear'])
        ['solo', '1girl']
    """
    blacklist = []
    blacklist = set()
    if use_presets:
        blacklist.extend(_load_online_blacklist())
    blacklist.extend(custom_blacklist or [])

    blacklist = set(tag.replace(' ', '_') for tag in blacklist)
    blacklist_update = set(tag.replace('_', ' ') for tag in blacklist)
    blacklist.update(blacklist_update)
        blacklist = blacklist | _online_blacklist_set()
    for tag in (custom_blacklist or []):
        blacklist = blacklist | _words_to_matcher(_split_to_words(tag))

    if isinstance(tags, dict):
        return {tag: value for tag, value in tags.items() if not _is_blacklisted(tag, blacklist)}
+4 −79
Original line number Diff line number Diff line
@@ -2,10 +2,9 @@
Overview:
    Detect and drop character-related basic tags.
"""
import re
from typing import Union, List, Mapping

from hbutils.string import singular_form, plural_form
from .match import tag_match_full, tag_match_prefix, tag_match_suffix

_CHAR_WHITELIST = [
    'drill', 'pubic_hair', 'closed_eyes', 'half-closed_eyes', 'empty_eyes'
@@ -21,80 +20,6 @@ _CHAR_PREFIXES = [
]


def _split_to_words(text: str) -> List[str]:
    """
    Split a string into words and return them in lowercase.

    :param text: The input text to split.
    :type text: str
    :return: List of lowercase words.
    :rtype: List[str]
    """
    return [word.lower() for word in re.split(r'[\W_]+', text) if word]


def _match_suffix(tag: str, suffix: str):
    """
    Check if a tag matches a given suffix.

    :param tag: The tag to check.
    :type tag: str
    :param suffix: The suffix to match.
    :type suffix: str
    :return: True if the tag matches the suffix, False otherwise.
    :rtype: bool
    """
    tag_words = _split_to_words(tag)
    suffix_words = _split_to_words(suffix)
    all_suffixes = [suffix_words]
    all_suffixes.append([*suffix_words[:-1], singular_form(suffix_words[0])])
    all_suffixes.append([*suffix_words[:-1], plural_form(suffix_words[0])])

    for suf in all_suffixes:
        if tag_words[-len(suf):] == suf:
            return True

    return False


def _match_prefix(tag: str, prefix: str):
    """
    Check if a tag matches a given prefix.

    :param tag: The tag to check.
    :type tag: str
    :param prefix: The prefix to match.
    :type prefix: str
    :return: True if the tag matches the prefix, False otherwise.
    :rtype: bool
    """
    tag_words = _split_to_words(tag)
    prefix_words = _split_to_words(prefix)
    return tag_words[:len(prefix_words)] == prefix_words


def _match_same(tag: str, expected: str):
    """
    Check if a tag matches another tag, considering singular and plural forms.

    :param tag: The tag to check.
    :type tag: str
    :param expected: The expected tag.
    :type expected: str
    :return: True if the tag matches the expected tag, False otherwise.
    :rtype: bool
    """
    a = _split_to_words(tag)
    as_ = [a, [*a[:-1], singular_form(a[-1])], [*a[:-1], plural_form(a[-1])]]
    as_ = set([tuple(item) for item in as_])

    b = _split_to_words(expected)
    bs_ = [b, [*b[:-1], singular_form(b[-1])], [*b[:-1], plural_form(b[-1])]]
    bs_ = set([tuple(item) for item in bs_])

    return bool(as_ & bs_)


def is_basic_character_tag(tag: str) -> bool:
    """
    Check if a tag is a basic character tag by matching with predefined whitelisted and blacklisted patterns.
@@ -122,11 +47,11 @@ def is_basic_character_tag(tag: str) -> bool:
        >>> is_basic_character_tag('dress')
        False
    """
    if any(_match_same(tag, wl_tag) for wl_tag in _CHAR_WHITELIST):
    if any(tag_match_full(tag, wl_tag) for wl_tag in _CHAR_WHITELIST):
        return False
    else:
        return (any(_match_suffix(tag, suffix) for suffix in _CHAR_SUFFIXES)
                or any(_match_prefix(tag, prefix) for prefix in _CHAR_PREFIXES))
        return (any(tag_match_suffix(tag, suffix) for suffix in _CHAR_SUFFIXES)
                or any(tag_match_prefix(tag, prefix) for prefix in _CHAR_PREFIXES))


def drop_basic_character_tags(tags: Union[List[str], Mapping[str, float]]) -> Union[List[str], Mapping[str, float]]:
Loading