Commit 08138eb7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): rerun tagger code

parent bbb4faa0
Loading
Loading
Loading
Loading
+20 −1
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags, get_deepgelbooru_tags


class DeepdanbooruBenchmark(BaseBenchmark):
@@ -18,6 +18,24 @@ class DeepdanbooruBenchmark(BaseBenchmark):
        _ = get_deepdanbooru_tags(image_file)


class DeepgelbooruBenchmark(BaseBenchmark):
    def load(self):
        from imgutils.tagging.deepgelbooru import _open_tags, _open_model, _open_preprocessor
        _ = _open_tags()
        _ = _open_model()
        _ = _open_preprocessor

    def unload(self):
        from imgutils.tagging.deepgelbooru import _open_tags, _open_model, _open_preprocessor
        _open_tags.cache_clear()
        _open_model.cache_clear()
        _open_preprocessor.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_deepgelbooru_tags(image_file)


class Wd14Benchmark(BaseBenchmark):
    def __init__(self, model):
        BaseBenchmark.__init__(self)
@@ -54,6 +72,7 @@ if __name__ == '__main__':
    create_plot_cli(
        [
            ('deepdanbooru', DeepdanbooruBenchmark()),
            ('deepgelbooru', DeepgelbooruBenchmark()),
            ('wd14-swinv2', Wd14Benchmark("SwinV2")),
            ('wd14-convnext', Wd14Benchmark("ConvNext")),
            ('wd14-convnextv2', Wd14Benchmark("ConvNextV2")),
+0 −3077

File deleted.

Preview size limit exceeded, changes collapsed.

+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ Overview:
from .blacklist import is_blacklisted, drop_blacklisted_tags
from .character import is_basic_character_tag, drop_basic_character_tags
from .deepdanbooru import get_deepdanbooru_tags
from .deepgelbooru import get_deepgelbooru_tags
from .format import tags_to_text, add_underline, remove_underline
from .match import tag_match_suffix, tag_match_prefix, tag_match_full
from .mldanbooru import get_mldanbooru_tags
+43 −1
Original line number Diff line number Diff line
import json

import numpy as np
import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download

from .overlap import drop_overlap_tags
from ..data import ImageTyping, load_image
from ..preprocess import create_pillow_transforms
from ..utils import ts_lru_cache, open_onnx_model
from ..utils import ts_lru_cache, open_onnx_model, vreplace

_REPO_ID = 'deepghs/deepgelbooru_onnx'

@@ -37,3 +41,41 @@ def _open_tags():
    ))
    return {item['tag_id']: item for item in df_tags.to_dict('records')}


def _image_preprocess(image: Image.Image):
    return _open_preprocessor()(image).transpose((1, 2, 0))[None, ...].astype(np.float32)


def get_deepgelbooru_tags(image: ImageTyping,
                          general_threshold: float = 0.3, character_threshold: float = 0.3,
                          drop_overlap: bool = False, fmt=('rating', 'general', 'character')):
    input_ = _image_preprocess(load_image(image, mode='RGB'))
    session = _open_model()
    prediction, = session.run(['prediction'], {'input': input_})
    prediction = prediction[0]

    d_tags = _open_tags()
    d_general, d_characters, d_rating = {}, {}, {}
    for idx, score in enumerate(prediction.tolist()):
        tag_info = d_tags[idx]
        category = tag_info['category']
        if category == 0:
            if score >= general_threshold:
                d_general[tag_info['name']] = score
        elif category == 4:
            if score >= character_threshold:
                d_characters[tag_info['name']] = score
        elif category == 9:
            d_rating[tag_info['name']] = score
        else:
            assert False, 'Should not reach this line.'  # pragma: no cover

    if drop_overlap:
        d_general = drop_overlap_tags(d_general)

    return vreplace(fmt, {
        'general': d_general,
        'character': d_characters,
        'rating': d_rating,
        'prediction': prediction,
    })