Commit 6ca8bafe authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix unittests

parent f79a2bf1
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
import os.path
from functools import lru_cache
from typing import List

@@ -167,6 +166,7 @@ def _detect_text(image: ImageTyping, model: str = 'ch_PP-OCRv4_det',
@lru_cache()
def _list_det_models() -> List[str]:
    retval = []
    for item in _HF_CLIENT.glob(f'{_REPOSITORY}/det/*/model.onnx', ):
        retval.append(os.path.relpath(item, _REPOSITORY).split('/')[1])
    repo_segment_cnt = len(_REPOSITORY.split('/'))
    for item in _HF_CLIENT.glob(f'{_REPOSITORY}/det/*/model.onnx'):
        retval.append(item.split('/')[repo_segment_cnt:][1])
    return retval
+4 −3
Original line number Diff line number Diff line
@@ -25,7 +25,7 @@ def _open_ocr_recognition_dictionary(model) -> List[str]:
    with open(hf_hub_download(
            _REPOSITORY,
            f'rec/{model}/dict.txt',
    ), 'r') as f:
    ), 'r', encoding='utf-8') as f:
        dict_ = [line.strip() for line in f]

    return ['<blank>', *dict_, ' ']
@@ -82,6 +82,7 @@ def _text_recognize(image: ImageTyping, model: str = 'ch_PP-OCRv4_rec',
@lru_cache()
def _list_rec_models() -> List[str]:
    retval = []
    for item in _HF_CLIENT.glob(f'{_REPOSITORY}/rec/*/model.onnx', ):
        retval.append(os.path.relpath(item, _REPOSITORY).split('/')[1])
    repo_segment_cnt = len(_REPOSITORY.split('/'))
    for item in _HF_CLIENT.glob(f'{_REPOSITORY}/rec/*/model.onnx'):
        retval.append(item.split('/')[repo_segment_cnt:][1])
    return retval
+28 −9
Original line number Diff line number Diff line
@@ -91,15 +91,34 @@ class TestOcr:
        detections = ocr(ocr_img_comic)
        assert len(detections) == 8

        assert detections == pytest.approx([
            ((742, 485, 809, 511), 'MOB.', 0.9356705927336156),
            ((716, 136, 836, 164), 'SHISHOU,', 0.8933000384412466),
            ((682, 98, 734, 124), 'BUT', 0.8730931912907247),
            ((144, 455, 196, 485), 'OH,', 0.8417627579351514),
            ((427, 129, 553, 154), 'A MIRROR.', 0.7366019454049503),
            ((1030, 557, 1184, 578), '(EL)  GATO IBERICO', 0.7271127306351021),
            ((719, 455, 835, 488), "THAt'S △", 0.701928390168364),
            ((124, 478, 214, 508), 'LOOK!', 0.6965972578194936),
        bboxes = []
        texts = []
        scores = []
        for bbox, text, score in detections:
            bboxes.append(bbox)
            texts.append(text)
            scores.append(score)

        assert bboxes == pytest.approx([
            (742, 485, 809, 511),
            (716, 136, 836, 164),
            (682, 98, 734, 124),
            (144, 455, 196, 485),
            (427, 129, 553, 154),
            (1030, 557, 1184, 578),
            (719, 455, 835, 488),
            (124, 478, 214, 508),
        ])
        assert texts == ['MOB.', 'SHISHOU,', 'BUT', 'OH,', 'A MIRROR.', '(EL)  GATO IBERICO', "THAt'S △", 'LOOK!']
        assert scores == pytest.approx([
            0.9356677655964869,
            0.8932994278321376,
            0.8730925493136663,
            0.8417598172118067,
            0.7365999885917329,
            0.7271122893745091,
            0.7019268051682541,
            0.6965953319577997
        ], abs=1e-3)

    def test_ocr_plot(self, ocr_img_plot):