Commit 7a2b9d43 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add deepdanbooru and tag format tool

parent f0572c40
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .wd14 import get_wd14_tags
+82 −0
Original line number Diff line number Diff line
from functools import lru_cache
from typing import Tuple, List

import numpy as np
import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model


@lru_cache()
def _get_deepdanbooru_labels():
    csv_file = hf_hub_download('deepghs/imgutils-models', 'deepdanbooru/deepdanbooru_tags.csv')
    df = pd.read_csv(csv_file)

    tag_names = df["name"].tolist()
    tag_real_names = df['real_name'].tolist()
    rating_indexes = list(np.where(df["category"] == 9)[0])
    general_indexes = list(np.where(df["category"] == 0)[0])
    character_indexes = list(np.where(df["category"] == 4)[0])
    return tag_names, tag_real_names, \
           rating_indexes, general_indexes, character_indexes


@lru_cache()
def _get_deepdanbooru_model():
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
        'deepdanbooru/deepdanbooru.onnx',
    ))


def _image_preprocess(image: Image.Image) -> np.ndarray:
    o_width, o_height = image.size
    scale = 512.0 / max(o_width, o_height)
    f_width, f_height = map(lambda x: int(x * scale), (o_width, o_height))
    image = image.resize((f_width, f_height))

    data = np.asarray(image).astype(np.float32) / 255  # H x W x C
    height_pad_left = (512 - f_height) // 2
    height_pad_right = 512 - f_height - height_pad_left
    width_pad_left = (512 - f_width) // 2
    width_pad_right = 512 - f_width - width_pad_left
    data = np.pad(data, ((height_pad_left, height_pad_right), (width_pad_left, width_pad_right), (0, 0)),
                  mode='constant', constant_values=0.0)

    assert data.shape == (512, 512, 3), f'Shape (512, 512, 3) expected, but {data.shape!r} found.'
    return data.reshape((1, 512, 512, 3))  # B x H x W x C


def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
                          general_threshold: float = 0.5, character_threshold: float = 0.5):
    session = _get_deepdanbooru_model()
    _image_data = _image_preprocess(load_image(image, mode='RGB'))

    input_name = session.get_inputs()[0].name
    output_names = [output.name for output in session.get_outputs()]
    probs = session.run(output_names, {input_name: _image_data})[0]

    tag_names, tag_real_names, rating_indexes, general_indexes, character_indexes = _get_deepdanbooru_labels()
    labels: List[Tuple[str, float]] = list(zip(
        tag_real_names if use_real_name else tag_names,
        probs[0].astype(float).tolist(),
    ))

    # First 4 labels are actually ratings: pick one with argmax
    ratings_names = [labels[i] for i in rating_indexes]
    rating = dict(ratings_names)

    # Then we have general tags: pick anywhere prediction confidence > threshold
    general_names = [labels[i] for i in general_indexes]
    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)

    # Everything else is characters: pick anywhere prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
    character_res = [x for x in character_names if x[1] > character_threshold]
    character_res = dict(character_res)

    return rating, general_res, character_res
+24 −0
Original line number Diff line number Diff line
import re
from typing import Mapping

RE_SPECIAL = re.compile(r'([\\()])')


def tags_to_text(tags: Mapping[str, float],
                 use_spaces: bool = False, use_escape: bool = True,
                 include_ranks: bool = False, score_descend: bool = True) -> str:
    text_items = []
    tags_pairs = tags.items()
    if score_descend:
        tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))
    for tag, score in tags_pairs:
        t_text = tag
        if use_spaces:
            t_text = t_text.replace('_', ' ')
        if use_escape:
            t_text = re.sub(RE_SPECIAL, r'\\\1', t_text)
        if include_ranks:
            t_text = f"({t_text}:{score:.3f})"
        text_items.append(t_text)

    return ', '.join(text_items)
+1 −1
Original line number Diff line number Diff line
@@ -136,7 +136,7 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
    probs = model.run([label_name], {input_name: image})[0]

    tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels()
    labels = list(zip(tag_names, probs[0].astype(float)))
    labels = list(zip(tag_names, probs[0].astype(float).tolist()))

    # First 4 labels are actually ratings: pick one with argmax
    ratings_names = [labels[i] for i in rating_indexes]
+20 −0
Original line number Diff line number Diff line
import pytest

from imgutils.tagging import get_deepdanbooru_tags
from test.testings import get_testfile


@pytest.mark.unittest
class TestTaggingDeepdanbooru:
    def test_get_deepdanbooru_tags(self):
        rating, tags, chars = get_deepdanbooru_tags(get_testfile('6124220.jpg'))
        assert rating['rating:safe'] > 0.9
        assert tags['greyscale'] >= 0.8
        assert tags['pixel_art'] >= 0.9
        assert not chars

        rating, tags, chars = get_deepdanbooru_tags(get_testfile('6125785.jpg'))
        assert rating['rating:safe'] > 0.9
        assert tags['1girl'] >= 0.85
        assert tags['ring'] > 0.8
        assert chars['hu_tao_(genshin_impact)'] >= 0.7
Loading