Commit f79a2bf1 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest and better docs

parent 2d94ec3b
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
"""
Overview:
    Detect and recognize text in images.

    The models are exported from `PaddleOCR <https://github.com/PaddlePaddle/PaddleOCR>`_, hosted on
    `huggingface - deepghs/paddleocr <https://huggingface.co/deepghs/paddleocr/tree/main>`_.

    .. image:: ocr_demo.plot.py.svg
        :align: center

    This is an overall benchmark of all the text detection models:

    .. image:: ocr_det_benchmark.plot.py.svg
        :align: center

    and an overall benchmark of all the available text recognition models:

    .. image:: ocr_rec_benchmark.plot.py.svg
        :align: center

"""
from .entry import detect_text_with_ocr, ocr, list_det_models, list_rec_models
+0 −3
Original line number Diff line number Diff line
@@ -23,9 +23,6 @@ def _open_ocr_detection_model(model):
        f'det/{model}/model.onnx',
    ))

    print(ort.get_inputs()[0].shape)
    return ort


def _box_score_fast(bitmap, _box):
    h, w = bitmap.shape[:2]
+56 −4
Original line number Diff line number Diff line
@@ -3,17 +3,28 @@ from typing import List, Tuple
from .detect import _detect_text, _list_det_models
from .recognize import _text_recognize, _list_rec_models
from ..data import ImageTyping, load_image
from ..utils import tqdm

_DEFAULT_DET_MODEL = 'ch_PP-OCRv4_det'
_DEFAULT_REC_MODEL = 'ch_PP-OCRv4_rec'


def list_det_models() -> List[str]:
    """
    List available text detection models for OCR.

    :return: A list of available text detection model names.
    :rtype: List[str]
    """
    return _list_det_models()


def list_rec_models() -> List[str]:
    """
    List available text recognition models for OCR.

    :return: A list of available text recognition model names.
    :rtype: List[str]
    """
    return _list_rec_models()


@@ -21,6 +32,24 @@ def detect_text_with_ocr(image: ImageTyping, model: str = _DEFAULT_DET_MODEL,
                         heat_threshold: float = 0.3, box_threshold: float = 0.7,
                         max_candidates: int = 1000, unclip_ratio: float = 2.0) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Detect text in an image using an OCR model.

    :param image: The input image.
    :type image: ImageTyping
    :param model: The name of the text detection model.
    :type model: str, optional
    :param heat_threshold: The heat map threshold for text detection.
    :type heat_threshold: float, optional
    :param box_threshold: The box threshold for text detection.
    :type box_threshold: float, optional
    :param max_candidates: The maximum number of candidates to consider.
    :type max_candidates: int, optional
    :param unclip_ratio: The unclip ratio for text detection.
    :type unclip_ratio: float, optional
    :return: A list of detected text boxes, their corresponding text content, and their confidence scores.
    :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
    """
    retval = []
    for box, _, score in _detect_text(image, model, heat_threshold, box_threshold, max_candidates, unclip_ratio):
        retval.append((box, 'text', score))
@@ -31,12 +60,35 @@ def detect_text_with_ocr(image: ImageTyping, model: str = _DEFAULT_DET_MODEL,
def ocr(image: ImageTyping, detect_model: str = _DEFAULT_DET_MODEL,
        recognize_model: str = _DEFAULT_REC_MODEL, heat_threshold: float = 0.3, box_threshold: float = 0.7,
        max_candidates: int = 1000, unclip_ratio: float = 2.0, rotation_threshold: float = 1.5,
        is_remove_duplicate: bool = False, silent: bool = False):
        is_remove_duplicate: bool = False):
    """
    Perform optical character recognition (OCR) on an image.

    :param image: The input image.
    :type image: ImageTyping
    :param detect_model: The name of the text detection model.
    :type detect_model: str, optional
    :param recognize_model: The name of the text recognition model.
    :type recognize_model: str, optional
    :param heat_threshold: The heat map threshold for text detection.
    :type heat_threshold: float, optional
    :param box_threshold: The box threshold for text detection.
    :type box_threshold: float, optional
    :param max_candidates: The maximum number of candidates to consider.
    :type max_candidates: int, optional
    :param unclip_ratio: The unclip ratio for text detection.
    :type unclip_ratio: float, optional
    :param rotation_threshold: The rotation threshold for text detection.
    :type rotation_threshold: float, optional
    :param is_remove_duplicate: Whether to remove duplicate text content.
    :type is_remove_duplicate: bool, optional
    :return: A list of detected text boxes, their corresponding text content, and their combined confidence scores.
    :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
    """
    image = load_image(image)
    retval = []
    for (x0, y0, x1, y1), _, score in \
            tqdm(_detect_text(image, detect_model, heat_threshold,
                              box_threshold, max_candidates, unclip_ratio), silent=silent):
            _detect_text(image, detect_model, heat_threshold, box_threshold, max_candidates, unclip_ratio):
        width, height = x1 - x0, y1 - y0
        area = image.crop((x0, y0, x1, y1))
        if height >= width * rotation_threshold:
+2 −2
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ def _open_ocr_recognition_dictionary(model) -> List[str]:
    return ['<blank>', *dict_, ' ']


def decode(text_index, model: str, text_prob=None, is_remove_duplicate=False):
def _text_decode(text_index, model: str, text_prob=None, is_remove_duplicate=False):
    retval = []
    ignored_tokens = [0]
    batch_size = len(text_index)
@@ -76,7 +76,7 @@ def _text_recognize(image: ImageTyping, model: str = 'ch_PP-OCRv4_rec',

    indices = output.argmax(axis=2)
    confs = output.max(axis=2)
    return decode(indices, model, confs, is_remove_duplicate)[0]
    return _text_decode(indices, model, confs, is_remove_duplicate)[0]


@lru_cache()

test/ocr/__init__.py

0 → 100644
+0 −0

Empty file added.

Loading