Commit 621a4071 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): a runnable version of char tags

parent f1b2ef59
Loading
Loading
Loading
Loading
+0 −7
Original line number Diff line number Diff line
@@ -13,10 +13,3 @@ drop_overlap_tags


drop_overlaps_for_dict
----------------------------------

.. autofunction:: drop_overlaps_for_dict


+1 −1
Original line number Diff line number Diff line
@@ -197,7 +197,7 @@ def parse_sdmeta_from_text(x: str) -> SDMetaData:
    x = textwrap.dedent(x).strip()
    *prompt_lines, argument_line = x.splitlines(keepends=False)
    if len(_PARAM_PATTERN.findall(argument_line)) < 3:
        prompt_lines.append(argument_line)
        prompt_lines._append(argument_line)
        argument_line = ''

    # 0x1 means prompt, 0x2 means neg prompt
+86 −16
Original line number Diff line number Diff line
@@ -2,23 +2,102 @@
Overview:
    Detect and drop character-related basic tags.
"""
from typing import Union, List, Mapping
from typing import Union, List, Mapping, Tuple, Dict, Set, Optional

from .match import tag_match_full, tag_match_prefix, tag_match_suffix
from .match import _split_to_words, _words_to_matcher

_CHAR_WHITELIST = [
CHAR_WHITELIST = [
    'drill', 'pubic_hair', 'closed_eyes', 'half-closed_eyes', 'empty_eyes'
]
_CHAR_SUFFIXES = [
CHAR_SUFFIXES = [
    'eyes', 'skin', 'hair', 'bun', 'bangs', 'cut', 'sidelocks',
    'twintails', 'braid', 'braids', 'afro', 'ahoge', 'drill',
    'drills', 'bald', 'dreadlocks', 'side up', 'ponytail', 'updo',
    'beard', 'mustache', 'pointy ears', 'ear', 'horn',
]
_CHAR_PREFIXES = [
CHAR_PREFIXES = [
    'hair over', 'hair between'
]

# Type alias for a variable-length tuple of strings; used as the element type
# of the sets that the suffix/prefix pools below keep, grouped by tuple length.
_WordTupleTyping = Tuple[str, ...]


class _SuffixPool:
    """
    Overview:
        Collection of suffix patterns supporting ``text in pool`` checks.

        Every registered text is split into words and expanded through
        ``_words_to_matcher``; the resulting tuples are stored grouped by their
        length, so a lookup only needs to compare the trailing words of the
        queried text against the tuples of the matching length.

    :param suffixes: Texts to register as suffixes. ``None`` (or any empty value)
        creates an empty pool.
    """

    def __init__(self, suffixes: Optional[List[str]] = None):
        # Maps tuple length -> set of matcher tuples having exactly that length.
        self._suffixes: Dict[int, Set[_WordTupleTyping]] = {}
        if suffixes:
            for entry in suffixes:
                self._append(entry)

    def _append(self, text: str):
        """
        Register one suffix text in the pool.

        :param text: Suffix text; it is split into words and every tuple produced
            by ``_words_to_matcher`` is stored under its own length.
        """
        for matcher in _words_to_matcher(_split_to_words(text)):
            self._suffixes.setdefault(len(matcher), set()).add(matcher)

    def __contains__(self, text: str):
        """
        Check whether the trailing words of ``text`` match a registered suffix.

        :param text: Text (tag) to check.
        :return: ``True`` if, for some stored length, the matcher tuples of the
            last ``length`` words intersect the stored tuples of that length.
        """
        words = _split_to_words(text)
        total = len(words)
        for size, known in self._suffixes.items():
            if size <= total:
                # ``words[-0:]`` would yield the whole sequence, so the
                # zero-length case is mapped to an empty list explicitly.
                tail = words[-size:] if size != 0 else []
                if _words_to_matcher(tail) & known:
                    return True

        return False


class _PrefixPool:
    """
    Overview:
        Collection of prefix patterns supporting ``text in pool`` checks.

        Every registered text is split into words and passed through
        ``_words_to_matcher`` with ``enable_forms=False``; the resulting tuples
        are stored grouped by their length, so a lookup only needs to compare the
        leading words of the queried text against the tuples of the matching length.

    :param prefixes: Texts to register as prefixes. ``None`` (or any empty value)
        creates an empty pool.
    """

    def __init__(self, prefixes: Optional[List[str]] = None):
        # Maps tuple length -> set of matcher tuples having exactly that length.
        self._prefixes: Dict[int, Set[_WordTupleTyping]] = {}
        if prefixes:
            for entry in prefixes:
                self._append(entry)

    def _append(self, text: str):
        """
        Register one prefix text in the pool.

        :param text: Prefix text; it is split into words and every tuple produced
            by ``_words_to_matcher`` (forms disabled) is stored under its own length.
        """
        for matcher in _words_to_matcher(_split_to_words(text), enable_forms=False):
            self._prefixes.setdefault(len(matcher), set()).add(matcher)

    def __contains__(self, text: str):
        """
        Check whether the leading words of ``text`` match a registered prefix.

        :param text: Text (tag) to check.
        :return: ``True`` if, for some stored length, the matcher tuples of the
            first ``length`` words intersect the stored tuples of that length.
        """
        words = _split_to_words(text)
        total = len(words)
        for size, known in self._prefixes.items():
            if size <= total:
                head = words[:size]
                if _words_to_matcher(head, enable_forms=False) & known:
                    return True

        return False


class CharacterTagPool:
    """
    Overview:
        Rule set deciding whether a tag is a basic character tag.

        A tag is treated as a basic character tag when it is NOT matched by the
        whitelist and it is matched by either the suffix pool or the prefix pool.

    :param whitelist: Tags that must never be treated as basic character tags.
        ``None`` or an empty list falls back to ``CHAR_WHITELIST``.
    :param suffixes: Suffix patterns of basic character tags.
        ``None`` or an empty list falls back to ``CHAR_SUFFIXES``.
    :param prefixes: Prefix patterns of basic character tags.
        ``None`` or an empty list falls back to ``CHAR_PREFIXES``.
    """

    def __init__(self, whitelist: Optional[List[str]] = None,
                 suffixes: Optional[List[str]] = None,
                 prefixes: Optional[List[str]] = None):
        # The whitelist is stored in a suffix pool as well, so whitelist entries
        # are compared against the trailing words of the checked tag.
        self._whitelist = _SuffixPool(whitelist or CHAR_WHITELIST)
        self._suffixes = _SuffixPool(suffixes or CHAR_SUFFIXES)
        self._prefixes = _PrefixPool(prefixes or CHAR_PREFIXES)

    def is_basic_character_tag(self, tag: str) -> bool:
        """
        Check whether a tag is a basic character tag according to this pool.

        :param tag: Tag to check.
        :return: ``False`` when the tag is whitelisted; otherwise ``True`` if the
            tag matches a suffix or a prefix of this pool, else ``False``.
        """
        # The whitelist takes precedence over suffix/prefix matches.
        if tag in self._whitelist:
            return False
        return (tag in self._suffixes) or (tag in self._prefixes)

    def drop_basic_character_tags(self, tags: Union[List[str], Mapping[str, float]]) \
            -> Union[List[str], Mapping[str, float]]:
        """
        Remove the basic character tags from a tag collection.

        :param tags: Either a list of tags, or a mapping from tag to its value.
            Any mapping type is accepted (not only ``dict``), matching the
            declared type annotation.
        :return: A new list (for list input) or a new ``dict`` (for mapping
            input) holding only the entries that are not basic character tags;
            the original order is kept.
        :raises TypeError: If ``tags`` is neither a mapping nor a list.
        """
        # Checking against ``Mapping`` rather than ``dict`` keeps read-only or
        # custom mappings from being rejected although the signature allows them.
        if isinstance(tags, Mapping):
            return {tag: value for tag, value in tags.items() if not self.is_basic_character_tag(tag)}
        elif isinstance(tags, list):
            return [tag for tag in tags if not self.is_basic_character_tag(tag)]
        else:
            raise TypeError(f"Unsupported types of tags, dict or list expected, but {tags!r} found.")


# Shared pool created without arguments, so it is built from the module-level
# CHAR_WHITELIST, CHAR_SUFFIXES and CHAR_PREFIXES lists. The module-level
# is_basic_character_tag / drop_basic_character_tags functions delegate to it.
_DEFAULT_CHARACTER_POOL = CharacterTagPool()


def is_basic_character_tag(tag: str) -> bool:
    """
@@ -47,11 +126,7 @@ def is_basic_character_tag(tag: str) -> bool:
        >>> is_basic_character_tag('dress')
        False
    """
    if any(tag_match_full(tag, wl_tag) for wl_tag in _CHAR_WHITELIST):
        return False
    else:
        return (any(tag_match_suffix(tag, suffix) for suffix in _CHAR_SUFFIXES)
                or any(tag_match_prefix(tag, prefix) for prefix in _CHAR_PREFIXES))
    return _DEFAULT_CHARACTER_POOL.is_basic_character_tag(tag)


def drop_basic_character_tags(tags: Union[List[str], Mapping[str, float]]) -> Union[List[str], Mapping[str, float]]:
@@ -78,9 +153,4 @@ def drop_basic_character_tags(tags: Union[List[str], Mapping[str, float]]) -> Un
        ... ])
        ['1girl', 'solo', 'chair', 'hear']
    """
    if isinstance(tags, dict):
        return {tag: value for tag, value in tags.items() if not is_basic_character_tag(tag)}
    elif isinstance(tags, list):
        return [tag for tag in tags if not is_basic_character_tag(tag)]
    else:
        raise TypeError(f"Unsupported types of tags, dict or list expected, but {tags!r} found.")
    return _DEFAULT_CHARACTER_POOL.drop_basic_character_tags(tags)