Unverified Commit ece3de24 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #13 from deepghs/doc/metrics

dev(narugo): add docs for metrics lpips
parents df82a909 a89c870d
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -25,6 +25,9 @@ class BaseBenchmark:
    def __init__(self):
        self.all_images = _DEFAULT_IMAGE_POOL

    def prepare(self):
        pass

    def load(self):
        raise NotImplementedError

@@ -42,6 +45,7 @@ class BaseBenchmark:
            logs.append((name, current_process.memory_info().rss, time.time()))

        # make sure the model is downloaded
        self.prepare()
        self.load()
        self.unload()

+48 −0
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.metrics import lpips_extract_feature, lpips_difference


class LPIPSFeatureBenchmark(BaseBenchmark):
    def load(self):
        from imgutils.metrics.lpips import _lpips_feature_model
        _ = _lpips_feature_model()

    def unload(self):
        from imgutils.metrics.lpips import _lpips_feature_model
        _lpips_feature_model.cache_clear()

    def run(self):
        image_file = random.choice(self.all_images)
        _ = lpips_extract_feature(image_file)


class LPIPSDiffBenchmark(BaseBenchmark):
    def prepare(self):
        self.feats = [lpips_extract_feature(img) for img in random.sample(self.all_images, k=30)]

    def load(self):
        from imgutils.metrics.lpips import _lpips_diff_model
        _ = _lpips_diff_model()

    def unload(self):
        from imgutils.metrics.lpips import _lpips_diff_model
        _lpips_diff_model.cache_clear()

    def run(self):
        feat1 = random.choice(self.feats)
        feat2 = random.choice(self.feats)
        _ = lpips_difference(feat1, feat2)


if __name__ == '__main__':
    create_plot_cli(
        [
            ('feature extract', LPIPSFeatureBenchmark()),
            ('diff calculate', LPIPSDiffBenchmark()),
        ],
        title='Benchmark for LPIPS Models',
        run_times=10,
        try_times=20,
    )()
+2262 −0

File added.

Preview size limit exceeded, changes collapsed.

+5 −0
Original line number Diff line number Diff line
@@ -5,6 +5,11 @@ Overview:

    When threshold is `0.45`, the `adjusted rand score <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html>`_ can reach `0.995`.

    This is an overall benchmark of all the operations in LPIPS models:

    .. image:: lpips.benchmark.py.svg
        :align: center

"""
from functools import lru_cache
from typing import Tuple, Union, List