Commit e7d67cd5 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add benchmark

parent 744b7a72
Loading
Loading
Loading
Loading
+28 −1
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags, get_deepgelbooru_tags, \
    get_camie_tags
    get_camie_tags, get_pixai_tags


class DeepdanbooruBenchmark(BaseBenchmark):
@@ -95,6 +95,32 @@ class CamieBenchmark(BaseBenchmark):
        _ = get_camie_tags(image_file, model_name=self.model_name)


class PixAITaggerBenchmark(BaseBenchmark):
    def __init__(self, model_name: str):
        BaseBenchmark.__init__(self)
        self.model_name = model_name

    def load(self):
        from imgutils.tagging.pixai import _open_tags, _open_preprocess, _open_onnx_model, \
            _open_default_category_thresholds
        _ = _open_tags(self.model_name)
        _ = _open_preprocess(self.model_name)
        _ = _open_onnx_model(self.model_name)
        _ = _open_default_category_thresholds(self.model_name)

    def unload(self):
        from imgutils.tagging.pixai import _open_tags, _open_preprocess, _open_onnx_model, \
            _open_default_category_thresholds
        _open_tags.cache_clear()
        _open_preprocess.cache_clear()
        _open_onnx_model.cache_clear()
        _open_default_category_thresholds.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_pixai_tags(image_file, model_name=self.model_name)


if __name__ == '__main__':
    create_plot_cli(
        [
@@ -113,6 +139,7 @@ if __name__ == '__main__':
            ('mldanbooru', MLDanbooruBenchmark()),
            ('camie-initial', CamieBenchmark('initial')),
            ('camie-refined', CamieBenchmark('refined')),
            ('pixai-tagger-v0.9', PixAITaggerBenchmark('v0.9')),
        ],
        title='Benchmark for Tagging Models',
        run_times=10,