Commit 2c7cf25d authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add overlap dropping for tags

parent 17df867e
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -13,3 +13,4 @@ imgutils.tagging
    wd14
    deepdanbooru
    format
    overlap
+22 −0
Original line number Diff line number Diff line
imgutils.tagging.overlap
====================================

.. currentmodule:: imgutils.tagging.overlap

.. automodule:: imgutils.tagging.overlap


drop_overlap_tags
----------------------------------

.. autofunction:: drop_overlap_tags



drop_overlaps_for_dict
----------------------------------

.. autofunction:: drop_overlaps_for_dict


+1 −0
Original line number Diff line number Diff line
@@ -11,4 +11,5 @@ Overview:
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .mldanbooru import get_mldanbooru_tags
from .overlap import drop_overlap_tags, drop_overlaps_for_dict
from .wd14 import get_wd14_tags
+6 −2
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

@@ -61,7 +62,8 @@ def _image_preprocess(image: Image.Image) -> np.ndarray:


def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
                          general_threshold: float = 0.5, character_threshold: float = 0.5):
                          general_threshold: float = 0.5, character_threshold: float = 0.5,
                          drop_overlap: bool = False):
    """
    Overview:
        Get tags for anime image based on ``deepdanbooru`` model.
@@ -120,6 +122,8 @@ def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
    general_names = [labels[i] for i in general_indexes]
    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)
    if drop_overlap:
        general_res = drop_overlaps_for_dict(general_res)

    # Everything else is characters: pick any whose prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
+8 −2
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlaps_for_dict
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model

@@ -57,7 +58,8 @@ def _get_mldanbooru_labels(use_real_name: bool = False) -> Tuple[List[str], List


def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
                        threshold: float = 0.7, size: int = 448, keep_ratio: bool = False):
                        threshold: float = 0.7, size: int = 448, keep_ratio: bool = False,
                        drop_overlap: bool = False):
    """
    Overview:
        Tagging image with ML-Danbooru, similar to
@@ -103,4 +105,8 @@ def get_mldanbooru_tags(image: ImageTyping, use_real_name: bool = False,
    output = (1 / (1 + np.exp(-native_output))).reshape(-1)
    tags = _get_mldanbooru_labels(use_real_name)
    pairs = sorted([(tags[i], ratio) for i, ratio in enumerate(output)], key=lambda x: (-x[1], x[0]))
    return {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}

    general_tags = {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
    if drop_overlap:
        general_tags = drop_overlaps_for_dict(general_tags)
    return general_tags
Loading