Commit 42d93fb7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add benchmark for tagging models

parent 364fae47
Loading
Loading
Loading
Loading
+131 −0
Original line number Diff line number Diff line
import glob
import multiprocessing
import os
import time
import warnings
from multiprocessing import Process
from typing import Tuple, List

import matplotlib.pyplot as plt
import numpy as np
import psutil
from hbutils.scale import size_to_bytes_str
from hbutils.string import ordinalize, plural_word
from matplotlib.ticker import FuncFormatter
from tqdm.auto import tqdm

from conf import PROJ_DIR
from plot import INCHES_TO_PIXELS

# Default pool of benchmark input images: every ``*.jpg`` file found
# recursively under ``<PROJ_DIR>/test/testfile/dataset``. The glob is evaluated
# once, when this module is imported.
_DEFAULT_IMAGE_POOL = glob.glob(os.path.join(PROJ_DIR, 'test', 'testfile', 'dataset', '**', '*.jpg'), recursive=True)


class BaseBenchmark:
    """
    Base class for memory / time benchmarks of a loadable model.

    Subclasses implement :meth:`load`, :meth:`unload` and :meth:`run`. The
    measuring logic lives in :meth:`run_benchmark` (in-process) and
    :meth:`run_in_subprocess` (one fresh process per try, results averaged).
    """

    def __init__(self):
        # All benchmarks draw their inputs from the same module-level image pool.
        self.all_images = _DEFAULT_IMAGE_POOL

    def load(self):
        """Load the model under test. Must be implemented by subclasses."""
        raise NotImplementedError

    def unload(self):
        """Release the model under test. Must be implemented by subclasses."""
        raise NotImplementedError

    def run(self):
        """Run the benchmarked operation once. Must be implemented by subclasses."""
        raise NotImplementedError

    def run_benchmark(self, run_times):
        """
        Measure resident memory and elapsed time around load / run / unload.

        A record ``(label, rss, timestamp)`` is taken at ``<init>``, after
        ``<load>``, after each of the ``run_times`` calls to :meth:`run`
        (labelled ``#1``, ``#2``, ...), and after ``<unload>``.

        :param run_times: Number of times :meth:`run` is called.
        :return: Tuple ``(mems, times, labels)`` of numpy arrays with one entry
            per record. ``mems`` is the RSS relative to the ``<init>`` record,
            ``times`` is the time elapsed since the previous record (``0`` for
            the first one), ``labels`` holds the record names.
        """
        logs = []
        current_process = psutil.Process()

        def _record(name):
            logs.append((name, current_process.memory_info().rss, time.time()))

        # make sure the model is downloaded
        self.load()
        self.unload()

        _record('<init>')

        self.load()
        _record('<load>')

        for i in tqdm(range(run_times)):
            self.run()
            _record(f'#{i + 1}')

        self.unload()
        _record('<unload>')

        mems = np.array([mem for _, mem, _ in logs])
        mems -= mems[0]
        times = np.array([time_ for _, _, time_ in logs])
        times -= times[0]
        # Turn "time since <init>" into "time since previous record"; the
        # right-hand side is fully evaluated before the assignment happens.
        times[1:] = times[1:] - times[:-1]
        labels = np.array([name for name, _, _ in logs])

        return mems, times, labels

    def _run_in_subprocess_share(self, run_times, ret):
        """Child-process entry point: store the benchmark result in ``ret['retval']``."""
        ret['retval'] = self.run_benchmark(run_times)

    def run_in_subprocess(self, run_times: int = 10, try_times: int = 10):
        """
        Run :meth:`run_benchmark` in ``try_times`` separate processes and average.

        Memory is averaged on the per-step deltas and then accumulated again,
        so the returned memory curve always starts at ``0``.

        :param run_times: Number of :meth:`run` calls per try.
        :param try_times: Number of child processes to spawn, one after another.
        :return: Tuple ``(final_mems, final_times, final_labels)``; the labels
            are taken from the first try.
        :raises ValueError: If ``try_times`` is less than ``1``.
        :raises ChildProcessError: If a child process exits with a non-zero code.
        """
        if try_times < 1:
            # Fail before spawning anything; np.stack would raise the same
            # exception type on the empty result list, only less clearly.
            raise ValueError(f'try_times should be at least 1, but {try_times!r} found.')

        full_deltas, full_times, final_labels = [], [], None
        # The manager owns a server process; the context manager shuts it down
        # on exit (including on error) instead of leaking it.
        with multiprocessing.Manager() as manager:
            for i in tqdm(range(try_times)):
                ret = manager.dict()
                p = Process(target=self._run_in_subprocess_share, args=(run_times, ret,))
                p.start()
                p.join()
                if p.exitcode != 0:
                    raise ChildProcessError(f'Exitcode {p.exitcode} in {self!r}\'s {ordinalize(i + 1)} try.')

                # Read the result while the manager is still alive.
                mems, times, labels = ret['retval']
                deltas = mems[1:] - mems[:-1]
                full_deltas.append(deltas)
                full_times.append(times)
                if final_labels is None:
                    final_labels = labels

        deltas = np.stack(full_deltas).mean(axis=0)
        final_mems = np.cumsum([0, *deltas])
        final_times = np.stack(full_times).mean(axis=0)

        return final_mems, final_times, final_labels


def create_plot(items: List[Tuple[str, BaseBenchmark]], save_as: str,
                title: str = 'Unnamed Benchmark Plot', run_times=15, try_times=10, figsize=(720, 420), dpi: int = 300):
    """
    Benchmark every item in a subprocess and save a two-panel comparison plot.

    The left panel shows memory usage per step, the right panel the time cost
    per step; each item contributes one line to both panels.

    :param items: ``(name, benchmark)`` pairs; ``name`` is used as legend label.
    :param save_as: Path of the image file to write.
    :param title: First line of the figure title.
    :param run_times: Passed to ``run_in_subprocess`` as its ``run_times``.
    :param try_times: Passed to ``run_in_subprocess`` as its ``try_times``.
    :param figsize: ``(width, height)``; divided by ``INCHES_TO_PIXELS`` before
        being handed to matplotlib.
    :param dpi: Resolution passed to ``savefig``.
    """

    def fmt_size(x, pos):
        _ = pos
        # Suppress warnings only around this call, so the process-wide warning
        # filters are left as the caller configured them.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return size_to_bytes_str(x, precision=1)

    fig, axes = plt.subplots(1, 2, figsize=(figsize[0] / INCHES_TO_PIXELS, figsize[1] / INCHES_TO_PIXELS))
    try:
        axes[0].yaxis.set_major_formatter(FuncFormatter(fmt_size))
        axes[0].set_title('Memory Benchmark')
        axes[0].set_ylabel('Memory Usage')
        axes[1].set_title('Performance Benchmark (CPU)')
        axes[1].set_ylabel('Time Cost (s)')
        labeled = False

        for name, bm in tqdm(items):
            mems, times, labels = bm.run_in_subprocess(run_times, try_times)
            axes[0].plot(mems, label=name)
            axes[1].plot(times, label=name)
            if not labeled:
                # Tick labels come from the first benchmark only.
                axes[0].set_xticks(range(len(labels)), labels, rotation='vertical')
                axes[1].set_xticks(range(len(labels)), labels, rotation='vertical')
                labeled = True

        axes[0].legend()
        axes[0].grid()
        axes[1].legend()
        axes[1].grid()

        fig.suptitle(f'{title}\n'
                     f'(Mean of {plural_word(try_times, "try")}, '
                     f'run for {plural_word(run_times, "time")})')

        fig.tight_layout()
        # Save this figure explicitly rather than whichever one pyplot
        # currently considers "current".
        fig.savefig(save_as, bbox_inches='tight', dpi=dpi, transparent=True)
    finally:
        # Release the figure even when a benchmark or the save fails.
        plt.close(fig)
+3 −0
Original line number Diff line number Diff line
import os

# Normalized parent directory of the path stored in the ``PROJ_DIR``
# environment variable.
# NOTE(review): when ``PROJ_DIR`` is not set, ``os.environ.get`` returns
# ``None`` and ``os.path.join`` raises ``TypeError`` at import time.
PROJ_DIR = os.path.normpath(os.path.join(os.environ.get('PROJ_DIR'), '..'))
+53 −0
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags


class DeepdanbooruBenchmark(BaseBenchmark):
    """Benchmark of ``get_deepdanbooru_tags`` on randomly chosen pool images."""

    def load(self):
        """Call the model getter once so the model is loaded."""
        from imgutils.tagging.deepdanbooru import _get_deepdanbooru_model as model_getter
        model_getter()

    def unload(self):
        """Clear the model getter's cache."""
        from imgutils.tagging.deepdanbooru import _get_deepdanbooru_model as model_getter
        model_getter.cache_clear()

    def run(self):
        """Tag one image picked at random from ``self.all_images``."""
        get_deepdanbooru_tags(random.choice(self.all_images))


class Wd14Benchmark(BaseBenchmark):
    """Benchmark of ``get_wd14_tags`` for one named model variant."""

    def __init__(self, model):
        """
        :param model: Model name handed to the wd14 model getter and to
            ``get_wd14_tags`` as ``model_name``.
        """
        super().__init__()
        self.model = model

    def load(self):
        """Call the model getter once for ``self.model``."""
        from imgutils.tagging.wd14 import _get_wd14_model as model_getter
        model_getter(self.model)

    def unload(self):
        """Clear the model getter's cache."""
        from imgutils.tagging.wd14 import _get_wd14_model as model_getter
        model_getter.cache_clear()

    def run(self):
        """Tag one image picked at random from ``self.all_images``."""
        picked = random.choice(self.all_images)
        get_wd14_tags(picked, model_name=self.model)


if __name__ == '__main__':
    # Legend label -> model name for each wd14 variant, in plotting order.
    wd14_variants = [
        ('wd14-swinv2', "SwinV2"),
        ('wd14-convnext', "ConvNext"),
        ('wd14-convnextv2', "ConvNextV2"),
        ('wd14-vit', "ViT"),
    ]

    # Deepdanbooru goes first, followed by the wd14 variants.
    benchmarks = [('deepdanbooru', DeepdanbooruBenchmark())]
    for label, model_name in wd14_variants:
        benchmarks.append((label, Wd14Benchmark(model_name)))

    create_plot(
        benchmarks,
        save_as='benchmark_tagging.dat.svg',
        title='Benchmark for Tagging Models',
        run_times=10,
        try_times=5,
        figsize=(1080, 600),
    )
+6 −0
Original line number Diff line number Diff line
"""
Overview:
    Get tags for anime images.

    This is an overall benchmark of all the danbooru models:

    .. image:: benchmark_tagging.dat.svg
        :align: center

"""
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text