Commit f1b2ef59 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): merge 2 functions for overlap

parent bfaa8d1f
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -15,5 +15,5 @@ from .format import tags_to_text
from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
from .order import sort_tags
from .overlap import drop_overlap_tags, drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from .wd14 import get_wd14_tags
+2 −2
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

@@ -124,7 +124,7 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)
    if drop_overlap:
        general_res = drop_overlaps_for_dict(general_res)
        general_res = drop_overlap_tags(general_res)

    # Everything else is characters: pick anywhere prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
+2 −2
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

@@ -109,5 +109,5 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,

    general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
    if drop_overlap:
        general_tags = drop_overlaps_for_dict(general_tags)
        general_tags = drop_overlap_tags(general_tags)
    return general_tags
+33 −40
Original line number Diff line number Diff line
import copy
import json
from functools import lru_cache
from typing import Mapping, List
from typing import Mapping, List, Union

from huggingface_hub import hf_hub_download

@@ -26,7 +27,7 @@ def _get_overlap_tags() -> Mapping[str, List[str]]:
    return data


def drop_overlap_tags(tags: List[str]) -> List[str]:
def drop_overlap_tags(tags: Union[List[str], Mapping[str, float]]) -> Union[List[str], Mapping[str, float]]:
    """
    Drop overlapping tags from the given list of tags.

@@ -47,13 +48,35 @@ def drop_overlap_tags(tags: List[str]) -> List[str]:
        ... ]
        >>> drop_overlap_tags(tags)
        ['1girl', 'solo', 'very_long_hair', 'red_hair', 'medium_breasts']
        >>>
        >>> tags = {
        ...     '1girl': 0.8849405313291128,
        ...     'solo': 0.8548297594823425,
        ...     'long_hair': 0.03910296474461261,
        ...     'very_long_hair': 0.6615180440330748,
        ...     'red_hair': 0.21552028866308015,
        ...     'breasts': 0.3165260620737027,
        ...     'medium_breasts': 0.47744464927382957,
        ... }
        >>> drop_overlap_tags(tags)
        {
            '1girl': 0.8849405313291128,
            'solo': 0.8548297594823425,
            'very_long_hair': 0.6615180440330748,
            'red_hair': 0.21552028866308015,
            'medium_breasts': 0.47744464927382957
        }
    """
    overlap_tags_dict = _get_overlap_tags()
    result_tags = []
    _origin_tags = copy.deepcopy(tags)
    if isinstance(tags, dict):
        tags = list(tags.keys())
    tags_underscore = [tag.replace(' ', '_') for tag in tags]

    tags: List[str]
    tags_underscore: List[str]
    for tag, tag_ in zip(tags, tags_underscore):

        to_remove = False

        # Case 1: If the tag is a key and some of the associated values are in tags
@@ -71,40 +94,10 @@ def drop_overlap_tags(tags: List[str]) -> List[str]:
        if not to_remove:
            result_tags.append(tag)

    if isinstance(_origin_tags, list):
        return result_tags


def drop_overlaps_for_dict(tags: Mapping[str, float]) -> Mapping[str, float]:
    """
    Drop overlapping tags from the given dictionary of tags with confidence scores.

    This function removes tags that have overlaps with other tags based on precomputed overlap information.

    :param tags: A dictionary where keys are tags and values are confidence scores.
    :type tags: Mapping[str, float]
    :return: A dictionary with non-overlapping tags and their corresponding confidence scores.
    :rtype: Mapping[str, float]

    Examples::
        >>> from imgutils.tagging import drop_overlaps_for_dict
        >>>
        >>> tags = {
        ...     '1girl': 0.8849405313291128,
        ...     'solo': 0.8548297594823425,
        ...     'long_hair': 0.03910296474461261,
        ...     'very_long_hair': 0.6615180440330748,
        ...     'red_hair': 0.21552028866308015,
        ...     'breasts': 0.3165260620737027,
        ...     'medium_breasts': 0.47744464927382957,
        ... }
        >>> drop_overlaps_for_dict(tags)
        {
            '1girl': 0.8849405313291128,
            'solo': 0.8548297594823425,
            'very_long_hair': 0.6615180440330748,
            'red_hair': 0.21552028866308015,
            'medium_breasts': 0.47744464927382957
        }
    """
    key_set = set(drop_overlap_tags(list(tags.keys())))
    return {tag: confidence for tag, confidence in tags.items() if tag in key_set}
    elif isinstance(_origin_tags, dict):
        _rtags_set = set(result_tags)
        return {key: value for key, value in _origin_tags.items() if key in _rtags_set}
    else:
        raise TypeError(f'Unknown tags type - {_origin_tags!r}.')  # pragma: no cover
+2 −2
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ import huggingface_hub
import numpy as np
import pandas as pd

from .overlap import drop_overlaps_for_dict
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

@@ -152,7 +152,7 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)
    if drop_overlap:
        general_res = drop_overlaps_for_dict(general_res)
        general_res = drop_overlap_tags(general_res)

    # Everything else is characters: pick anywhere prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
Loading