Commit e7b8cd15 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): try fix unittest

parent e1ff174d
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -34,6 +34,9 @@ class BaseBenchmark:
    def unload(self):
        raise NotImplementedError

    def after_unload(self):
        pass

    def run(self):
        raise NotImplementedError

@@ -60,6 +63,7 @@ class BaseBenchmark:

        self.unload()
        _record('<unload>')
        self.after_unload()

        mems = np.array([mem for _, mem, _ in logs])
        mems -= mems[0]
+25 −12
Original line number Diff line number Diff line
import random

from hfutils.cache import delete_cache

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags, get_deepgelbooru_tags, \
    get_camie_tags, get_pixai_tags


class DeepdanbooruBenchmark(BaseBenchmark):
class CleanModelStorageBenchmark(BaseBenchmark):
    def after_unload(self):
        delete_cache()


class DeepdanbooruBenchmark(CleanModelStorageBenchmark):
    def load(self):
        from imgutils.tagging.deepdanbooru import _get_deepdanbooru_model
        from imgutils.tagging.deepdanbooru import _get_deepdanbooru_model, _get_deepdanbooru_labels
        _ = _get_deepdanbooru_model()
        _ = _get_deepdanbooru_labels

    def unload(self):
        from imgutils.tagging.deepdanbooru import _get_deepdanbooru_model
        from imgutils.tagging.deepdanbooru import _get_deepdanbooru_model, _get_deepdanbooru_labels
        _get_deepdanbooru_model.cache_clear()
        _get_deepdanbooru_labels.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_deepdanbooru_tags(image_file)


class DeepgelbooruBenchmark(BaseBenchmark):
class DeepgelbooruBenchmark(CleanModelStorageBenchmark):
    def load(self):
        from imgutils.tagging.deepgelbooru import _open_tags, _open_model, _open_preprocessor
        _ = _open_tags()
@@ -37,39 +46,43 @@ class DeepgelbooruBenchmark(BaseBenchmark):
        _ = get_deepgelbooru_tags(image_file)


class Wd14Benchmark(BaseBenchmark):
class Wd14Benchmark(CleanModelStorageBenchmark):
    def __init__(self, model):
        BaseBenchmark.__init__(self)
        self.model = model

    def load(self):
        from imgutils.tagging.wd14 import _get_wd14_model
        from imgutils.tagging.wd14 import _get_wd14_model, _get_wd14_labels
        _ = _get_wd14_model(self.model)
        _ = _get_wd14_labels(self.model)

    def unload(self):
        from imgutils.tagging.wd14 import _get_wd14_model
        from imgutils.tagging.wd14 import _get_wd14_model, _get_wd14_labels
        _get_wd14_model.cache_clear()
        _get_wd14_labels.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_wd14_tags(image_file, model_name=self.model)


class MLDanbooruBenchmark(BaseBenchmark):
class MLDanbooruBenchmark(CleanModelStorageBenchmark):
    def load(self):
        from imgutils.tagging.mldanbooru import _open_mldanbooru_model
        from imgutils.tagging.mldanbooru import _open_mldanbooru_model, _get_mldanbooru_labels
        _ = _open_mldanbooru_model()
        _ = _get_mldanbooru_labels()

    def unload(self):
        from imgutils.tagging.mldanbooru import _open_mldanbooru_model
        from imgutils.tagging.mldanbooru import _open_mldanbooru_model, _get_mldanbooru_labels
        _open_mldanbooru_model.cache_clear()
        _get_mldanbooru_labels.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = get_mldanbooru_tags(image_file)


class CamieBenchmark(BaseBenchmark):
class CamieBenchmark(CleanModelStorageBenchmark):
    def __init__(self, model_name):
        BaseBenchmark.__init__(self)
        self.model_name = model_name
@@ -95,7 +108,7 @@ class CamieBenchmark(BaseBenchmark):
        _ = get_camie_tags(image_file, model_name=self.model_name)


class PixAITaggerBenchmark(BaseBenchmark):
class PixAITaggerBenchmark(CleanModelStorageBenchmark):
    def __init__(self, model_name: str):
        BaseBenchmark.__init__(self)
        self.model_name = model_name
+5 −2
Original line number Diff line number Diff line
@@ -22,5 +22,8 @@ def text_aligner():

@pytest.fixture(autouse=True, scope='module')
def clean_hf_cache():
    try:
        yield
    finally:
        if os.environ.get('CI'):
            delete_cache()