Commit 8e70591b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add benchmark for detection

parent 70a1a3ce
Loading
Loading
Loading
Loading
+5 −7
Original line number Diff line number Diff line
@@ -4,13 +4,11 @@ from imgutils.detect.visual import detection_visualize
from imgutils.ocr import ocr
from plot import image_plot

_max_size = 480


def _detect(img, **kwargs):
def _detect(img, *, max_size=None, **kwargs):
    img = load_image(img, mode='RGB', force_background='white')
    if min(img.height, img.width) > _max_size:
        r = _max_size / min(img.height, img.width)
    if max_size is not None and min(img.height, img.width) > max_size:
        r = max_size / min(img.height, img.width)
        img = img.resize((
            int(round(img.width * r)),
            int(round(img.height * r)),
@@ -21,10 +19,10 @@ def _detect(img, **kwargs):

if __name__ == '__main__':
    image_plot(
        (_detect('post_text.jpg', recognize_model='japan_PP-OCRv3_rec'), 'Text of Post'),
        (_detect('post_text.jpg', recognize_model='japan_PP-OCRv3_rec', max_size=480), 'Text of Post'),
        (_detect('anime_subtitle.jpg'), 'Subtitle of Anime'),
        (_detect('comic.jpg'), 'Comic'),
        (_detect('plot.png'), 'Complex'),
        columns=2,
        figsize=(12, 9),
        figsize=(13, 7.5),
    )
+0 −539

File deleted.

Preview size limit exceeded, changes collapsed.

+34 −0
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.ocr.detect import _detect_text, _list_det_models


class OCRDetectBenchmark(BaseBenchmark):
    """Benchmark case for one OCR text-detection model.

    Each instance is bound to a single model name. ``load`` opens that
    model, ``unload`` clears the opener's cache, and ``run`` performs one
    detection on a randomly chosen image.

    NOTE(review): ``self.all_images`` is not defined in this class; it is
    presumably provided by ``BaseBenchmark`` -- confirm against the base.
    """

    def __init__(self, model):
        """Store the name of the detection model to benchmark.

        :param model: Model name forwarded to the detection helpers.
        """
        super().__init__()
        self.model = model

    def load(self):
        """Open the detection model for ``self.model``; the result is discarded."""
        # Imported lazily, matching the original layout of this benchmark.
        from imgutils.ocr.detect import _open_ocr_detection_model as open_model
        open_model(self.model)

    def unload(self):
        """Clear the cache held by the model-opening helper."""
        from imgutils.ocr.detect import _open_ocr_detection_model as open_model
        open_model.cache_clear()

    def run(self):
        """Run text detection once on a random image from ``self.all_images``."""
        sample = random.choice(self.all_images)
        _detect_text(sample, model=self.model)


if __name__ == '__main__':
    # One (label, benchmark) pair per detection model name returned by
    # _list_det_models(); the model name itself serves as the label.
    benchmarks = [
        (name, OCRDetectBenchmark(name))
        for name in _list_det_models()
    ]

    # create_plot_cli returns a callable, which is invoked right away.
    cli = create_plot_cli(
        benchmarks,
        title='Benchmark for OCR Detections',
        run_times=10,
        try_times=20,
    )
    cli()
+2 −2
Original line number Diff line number Diff line
@@ -123,7 +123,7 @@ def _normalize(data, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954,
_ALIGN = 64


def _get_text_points(image: ImageTyping, model: str = 'ch_PP-OCRv4_det_infer',
def _get_text_points(image: ImageTyping, model: str = 'ch_PP-OCRv4_det',
                     heat_threshold: float = 0.3, box_threshold: float = 0.7,
                     max_candidates: int = 1000, unclip_ratio: float = 2.0):
    origin_width, origin_height = width, height = image.size
@@ -154,7 +154,7 @@ def _get_text_points(image: ImageTyping, model: str = 'ch_PP-OCRv4_det_infer',
    return retval


def _detect_text(image: ImageTyping, model: str = 'ch_PP-OCRv4_det_infer',
def _detect_text(image: ImageTyping, model: str = 'ch_PP-OCRv4_det',
                 heat_threshold: float = 0.3, box_threshold: float = 0.7,
                 max_candidates: int = 1000, unclip_ratio: float = 2.0):
    image = load_image(image, force_background='white', mode='RGB')
+1 −1
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ def decode(text_index, model: str, text_prob=None, is_remove_duplicate=False):
    return retval


def _text_recognize(image: ImageTyping, model: str = 'ch_PP-OCRv4_det_infer',
def _text_recognize(image: ImageTyping, model: str = 'ch_PP-OCRv4_rec',
                    is_remove_duplicate: bool = False) -> Tuple[str, float]:
    image = load_image(image, force_background='white', mode='RGB')
    r = 48 / image.height