Unverified Commit 0732f1a8 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #158 from deepghs/dev/camie

dev(narugo): add support for camie taggers
parents 6a202a9b 747d6bf2
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
imgutils.tagging.camie
====================================

.. currentmodule:: imgutils.tagging.camie

.. automodule:: imgutils.tagging.camie


get_camie_tags
----------------------

.. autofunction:: get_camie_tags



convert_camie_emb_to_prediction
----------------------------------------------------

.. autofunction:: convert_camie_emb_to_prediction

+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ imgutils.tagging

    mldanbooru
    wd14
    camie
    deepdanbooru
    deepgelbooru
    format
+30 −1
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags, get_deepgelbooru_tags
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags, get_deepgelbooru_tags, \
    get_camie_tags


class DeepdanbooruBenchmark(BaseBenchmark):
@@ -68,6 +69,32 @@ class MLDanbooruBenchmark(BaseBenchmark):
        _ = get_mldanbooru_tags(image_file)


class CamieBenchmark(BaseBenchmark):
    def __init__(self, model_name):
        BaseBenchmark.__init__(self)
        self.model_name = model_name

    def load(self):
        from imgutils.tagging.camie import _get_camie_model, _get_camie_labels, _get_camie_threshold, \
            _get_camie_preprocessor
        _ = _get_camie_model(self.model_name)
        _ = _get_camie_labels(self.model_name)
        _ = _get_camie_threshold(self.model_name)
        _ = _get_camie_preprocessor(self.model_name)

    def unload(self):
        from imgutils.tagging.camie import _get_camie_model, _get_camie_labels, _get_camie_threshold, \
            _get_camie_preprocessor
        _get_camie_model.cache_clear()
        _get_camie_labels.cache_clear()
        _get_camie_threshold.cache_clear()
        _get_camie_preprocessor.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_camie_tags(image_file, model_name=self.model_name)


if __name__ == '__main__':
    create_plot_cli(
        [
@@ -84,6 +111,8 @@ if __name__ == '__main__':
            ('wd-vit-large-tagger-v3', Wd14Benchmark("ViT_Large")),
            ('wd-eva02-large-tagger-v3', Wd14Benchmark("EVA02_Large")),
            ('mldanbooru', MLDanbooruBenchmark()),
            ('camie-initial', CamieBenchmark('initial')),
            ('camie-refined', CamieBenchmark('refined')),
        ],
        title='Benchmark for Tagging Models',
        run_times=10,
+1022 −862

File changed.

Preview size limit exceeded, changes collapsed.

+1 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ Overview:

"""
from .blacklist import is_blacklisted, drop_blacklisted_tags
from .camie import get_camie_tags, convert_camie_emb_to_prediction
from .character import is_basic_character_tag, drop_basic_character_tags
from .deepdanbooru import get_deepdanbooru_tags
from .deepgelbooru import get_deepgelbooru_tags
Loading