Commit eb024208 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): regenerate taggers benchmark

parent c545587f
Loading
Loading
Loading
Loading
+30 −1
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags, get_deepgelbooru_tags
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags, get_deepgelbooru_tags, \
    get_camie_tags


class DeepdanbooruBenchmark(BaseBenchmark):
@@ -68,6 +69,32 @@ class MLDanbooruBenchmark(BaseBenchmark):
        _ = get_mldanbooru_tags(image_file)


class CamieBenchmark(BaseBenchmark):
    def __init__(self, model_name):
        BaseBenchmark.__init__(self)
        self.model_name = model_name

    def load(self):
        from imgutils.tagging.camie import _get_camie_model, _get_camie_labels, _get_camie_threshold, \
            _get_camie_preprocessor
        _ = _get_camie_model(self.model_name)
        _ = _get_camie_labels(self.model_name)
        _ = _get_camie_threshold(self.model_name)
        _ = _get_camie_preprocessor(self.model_name)

    def unload(self):
        from imgutils.tagging.camie import _get_camie_model, _get_camie_labels, _get_camie_threshold, \
            _get_camie_preprocessor
        _get_camie_model.cache_clear()
        _get_camie_labels.cache_clear()
        _get_camie_threshold.cache_clear()
        _get_camie_preprocessor.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_camie_tags(image_file, model_name=self.model_name)


if __name__ == '__main__':
    create_plot_cli(
        [
@@ -84,6 +111,8 @@ if __name__ == '__main__':
            ('wd-vit-large-tagger-v3', Wd14Benchmark("ViT_Large")),
            ('wd-eva02-large-tagger-v3', Wd14Benchmark("EVA02_Large")),
            ('mldanbooru', MLDanbooruBenchmark()),
            ('camie-initial', CamieBenchmark('initial')),
            ('camie-refined', CamieBenchmark('refined')),
        ],
        title='Benchmark for Tagging Models',
        run_times=10,
+0 −3155

File deleted.

Preview size limit exceeded, changes collapsed.