Commit f2697942 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add more processing method for tags

parent 8c204738
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -8,8 +8,11 @@ Overview:
        :align: center

"""
from .blacklist import is_blacklisted, drop_blacklisted_tags
from .character import is_basic_character_tag, drop_basic_character_tags
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags, drop_overlaps_for_dict
from .wd14 import get_wd14_tags
+49 −0
Original line number Diff line number Diff line
from functools import lru_cache
from typing import Union, List, Mapping, Set, Optional

from huggingface_hub import hf_hub_download


@lru_cache()
def _load_online_blacklist() -> List[str]:
    with open(hf_hub_download(
            'alea31415/tag_filtering',
            'blacklist_tags.txt',
            repo_type='dataset',
    ), 'r') as f:
        return [line.strip() for line in f if line.strip()]


def _is_blacklisted(tag: str, blacklist: Set[str]):
    return (tag in blacklist or
            tag.replace('_', ' ') in blacklist or
            tag.replace(' ', '_') in blacklist)


@lru_cache()
def _online_blacklist_set() -> Set[str]:
    return set(_load_online_blacklist())


def is_blacklisted(tags: str):
    return _is_blacklisted(tags, _online_blacklist_set())


def drop_blacklisted_tags(tags: Union[List[str], Mapping[str, float]],
                          use_presets: bool = True, custom_blacklist: Optional[List[str]] = None) \
        -> Union[List[str], Mapping[str, float]]:
    blacklist = []
    if use_presets:
        blacklist.extend(_load_online_blacklist())
    blacklist.extend(custom_blacklist or [])

    blacklist = set(tag.replace(' ', '_') for tag in blacklist)
    blacklist_update = set(tag.replace('_', ' ') for tag in blacklist)
    blacklist.update(blacklist_update)

    if isinstance(tags, dict):
        return {tag: value for tag, value in tags.items() if not _is_blacklisted(tag, blacklist)}
    elif isinstance(tags, list):
        return [tag for tag in tags if not _is_blacklisted(tag, blacklist)]
    else:
        raise TypeError(f"Unsupported types of tags, dict or list expected, but {tags!r} found.")
+70 −0
Original line number Diff line number Diff line
import re
from typing import Union, List, Mapping

from hbutils.string import singular_form, plural_form

_CHAR_WHITELIST = [
    'drill', 'pubic_hair', 'closed_eyes', 'half-closed_eyes', 'empty_eyes'
]
_CHAR_SUFFIXES = [
    'eyes', 'skin', 'hair', 'bun', 'bangs', 'cut', 'sidelocks',
    'twintails', 'braid', 'braids', 'afro', 'ahoge', 'drill',
    'drills', 'bald', 'dreadlocks', 'side up', 'ponytail', 'updo',
    'beard', 'mustache', 'pointy ears', 'ear', 'horn',
]
_CHAR_PREFIXES = [
    'hair over', 'hair between'
]


def _split_to_words(text: str) -> List[str]:
    return [word.lower() for word in re.split(r'[\W_]+', text) if word]


def _match_suffix(tag: str, suffix: str):
    tag_words = _split_to_words(tag)
    suffix_words = _split_to_words(suffix)
    all_suffixes = [suffix_words]
    all_suffixes.append([*suffix_words[:-1], singular_form(suffix_words[0])])
    all_suffixes.append([*suffix_words[:-1], plural_form(suffix_words[0])])

    for suf in all_suffixes:
        if tag_words[-len(suf):] == suf:
            return True

    return False


def _match_prefix(tag: str, prefix: str):
    tag_words = _split_to_words(tag)
    prefix_words = _split_to_words(prefix)
    return tag_words[:len(prefix_words)] == prefix_words


def _match_same(tag: str, expected: str):
    a = _split_to_words(tag)
    as_ = [a, [*a[:-1], singular_form(a[-1])], [*a[:-1], plural_form(a[-1])]]
    as_ = set([tuple(item) for item in as_])

    b = _split_to_words(expected)
    bs_ = [b, [*b[:-1], singular_form(b[-1])], [*b[:-1], plural_form(b[-1])]]
    bs_ = set([tuple(item) for item in bs_])

    return bool(as_ & bs_)


def is_basic_character_tag(tag: str) -> bool:
    if any(_match_same(tag, wl_tag) for wl_tag in _CHAR_WHITELIST):
        return False
    else:
        return (any(_match_suffix(tag, suffix) for suffix in _CHAR_SUFFIXES)
                or any(_match_prefix(tag, prefix) for prefix in _CHAR_PREFIXES))


def drop_basic_character_tags(tags: Union[List[str], Mapping[str, float]]) -> Union[List[str], Mapping[str, float]]:
    if isinstance(tags, dict):
        return {tag: value for tag, value in tags.items() if not is_basic_character_tag(tag)}
    elif isinstance(tags, list):
        return [tag for tag in tags if not is_basic_character_tag(tag)]
    else:
        raise TypeError(f"Unsupported types of tags, dict or list expected, but {tags!r} found.")
+41 −0
Original line number Diff line number Diff line
import random
import re
from typing import Union, List, Mapping

try:
    from typing import Literal
except (ImportError, ModuleNotFoundError):
    from typing_extensions import Literal


def sort_tags(tags: Union[List[str], Mapping[str, float]],
              mode: Literal['original', 'shuffle', 'score'] = 'score') -> List[str]:
    if mode not in {'original', 'shuffle', 'score'}:
        raise ValueError(f'Unknown sort_mode, \'original\', '
                         f'\'shuffle\' or \'score\' expected but {mode!r} found.')
    npeople_tags = []
    remaining_tags = []

    if 'solo' in tags:
        npeople_tags.append('solo')

    for tag in tags:
        if tag == 'solo':
            continue
        if re.fullmatch(r'^\d+\+?(boy|girl)s?$', tag):  # 1girl, 1boy, 2girls, 3boys, 9+girls
            npeople_tags.append(tag)
        else:
            remaining_tags.append(tag)

    if mode == 'score':
        if isinstance(tags, dict):
            remaining_tags = sorted(remaining_tags, key=lambda x: -tags[x])
        else:
            raise TypeError(f'Sort mode {mode!r} not supported for list, '
                            f'for it do not have scores.')
    elif mode == 'shuffle':
        random.shuffle(remaining_tags)
    else:
        pass

    return npeople_tags + remaining_tags
+23 −0
Original line number Diff line number Diff line
import pytest


@pytest.fixture()
def complex_dict_tags():
    return {
        '1girl': 0.998362123966217, 'solo': 0.9912548065185547, 'long_hair': 0.9401906728744507,
        'breasts': 0.983635425567627, 'looking_at_viewer': 0.9146994352340698, 'blush': 0.8892400860786438,
        'smile': 0.43393653631210327, 'bangs': 0.49712443351745605, 'large_breasts': 0.5196534395217896,
        'navel': 0.9653235077857971, 'hair_between_eyes': 0.5786703824996948, 'very_long_hair': 0.8142435550689697,
        'closed_mouth': 0.9369247555732727, 'nipples': 0.9660118222236633, 'purple_eyes': 0.9676010012626648,
        'collarbone': 0.588348925113678, 'nude': 0.9496222734451294, 'red_hair': 0.9200156331062317,
        'sweat': 0.8690457344055176, 'horns': 0.9711267948150635, 'pussy': 0.9868264198303223,
        'spread_legs': 0.9603149890899658, 'armpits': 0.9024748802185059, 'stomach': 0.6723923087120056,
        'arms_up': 0.9380699396133423, 'completely_nude': 0.9002960920333862, 'uncensored': 0.8612104058265686,
        'pussy_juice': 0.6021570563316345, 'feet_out_of_frame': 0.39779460430145264, 'on_bed': 0.610720157623291,
        'arms_behind_head': 0.44814401865005493, 'breasts_apart': 0.39798974990844727, 'clitoris': 0.5310801267623901,
    }


@pytest.fixture()
def complex_list_tags(complex_dict_tags):
    return list(complex_dict_tags.keys())
Loading