Commit e79705a6 authored by dmMaze's avatar dmMaze
Browse files

Auto split textlines for mit ocr models when running OCR manually (#483, #110)

parent 93d9c5e6
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ class OCRBase(BaseModule):
                self.name = key
                break

    def run_ocr(self, img: np.ndarray, blk_list: List[TextBlock] = None) -> Union[List[TextBlock], str]:
    def run_ocr(self, img: np.ndarray, blk_list: List[TextBlock] = None, *args, **kwargs) -> Union[List[TextBlock], str]:

        if not self.all_model_loaded():
            self.load_model()
@@ -39,7 +39,7 @@ class OCRBase(BaseModule):
            if self.name != 'none_ocr':
                blk.text = []
                
        self._ocr_blk_list(img, blk_list)
        self._ocr_blk_list(img, blk_list, *args, **kwargs)
        for callback_name, callback in self._postprocess_hooks.items():
            callback(textblocks=blk_list, img=img, ocr_module=self)

+2 −12
Original line number Diff line number Diff line
@@ -544,6 +544,7 @@ class OCR32pxModel:
    def __init__(self, model_path, device='cpu') -> None:
        self.device = device
        self.text_height = 32
        self.maxwidth = 3064

        self.net = None
        with open('data/alphabet-all-v5.txt', 'r', encoding = 'utf-8') as fp :
@@ -561,18 +562,7 @@ class OCR32pxModel:
        self.device = device

    @torch.no_grad()
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock], chunk_size = 16, regions: List = None, textblk_lst_indices: List = None) -> None:
        if isinstance(textblk_lst, TextBlock):
            textblk_lst = [textblk_lst]

        if regions is None or textblk_lst_indices is None:
            regions = []
            textblk_lst_indices = []
            for blk_idx, textblk in enumerate(textblk_lst):
                for ii in range(len(textblk)):
                    textblk_lst_indices.append(blk_idx)
                    region = textblk.get_transformed_region(img, ii, self.text_height, maxwidth=3064)
                    regions.append(region)
    def __call__(self, textblk_lst: List[TextBlock], regions: List[np.ndarray], textblk_lst_indices: List, chunk_size = 16) -> None:

        perm = range(len(regions))
        chunck_idx = 0
+4 −16
Original line number Diff line number Diff line
@@ -124,6 +124,8 @@ class Model48pxOCR:
        super().__init__(*args, **kwargs)

        self.device = device
        self.text_height = 48
        self.maxwidth = 8100

        with open('data/alphabet-all-v7.txt', 'r', encoding = 'utf-8') as fp:
            dictionary = [s[:-1] for s in fp.readlines()]
@@ -138,28 +140,14 @@ class Model48pxOCR:
        self.model.to(device)
        self.device = device
    
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock], chunk_size = 16, regions: List = None, textblk_lst_indices: List = None) -> None:
        if isinstance(textblk_lst, TextBlock):
            textblk_lst = [textblk_lst]
        
        text_height = 48

        if regions is None or textblk_lst_indices is None:
            regions = []
            textblk_lst_indices = []
            for blk_idx, textblk in enumerate(textblk_lst):
                for ii in range(len(textblk)):
                    textblk_lst_indices.append(blk_idx)
                    region = textblk.get_transformed_region(img, ii, 48, maxwidth=8100)
                    regions.append(region)

    def __call__(self, textblk_lst: List[TextBlock], regions: List[np.ndarray], textblk_lst_indices: List, chunk_size = 16) -> None:
        perm = range(len(regions))
        chunck_idx = 0
        for indices in chunks(perm, chunk_size):
            N = len(indices)
            widths = [regions[i].shape[1] for i in indices]
            max_width = 4 * (max(widths) + 7) // 4
            region = np.zeros((N, text_height, max_width, 3), dtype = np.uint8)
            region = np.zeros((N, self.text_height, max_width, 3), dtype = np.uint8)
            for i, idx in enumerate(indices):
                W = regions[idx].shape[1]
                region[i, :, : W, :]=regions[idx]
+4 −13
Original line number Diff line number Diff line
@@ -396,6 +396,8 @@ class OCR48pxCTC:
        with open('data/alphabet-all-v5.txt', 'r', encoding = 'utf-8') as fp :
            dictionary = [s[:-1] for s in fp.readlines()]
        self.device = device
        self.text_height = 48
        self.maxwidth = 8100

        model = OCR(dictionary, 768)
        sd = torch.load(model_path, map_location = 'cpu')
@@ -413,18 +415,7 @@ class OCR48pxCTC:
        self.device = device

    @torch.no_grad()
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock], chunk_size = 16, regions: List = None, textblk_lst_indices: List = None) -> None:
        if isinstance(textblk_lst, TextBlock):
            textblk_lst = [textblk_lst]
        
        if regions is None or textblk_lst_indices is None:
            regions = []
            textblk_lst_indices = []
            for blk_idx, textblk in enumerate(textblk_lst):
                for ii in range(len(textblk)):
                    textblk_lst_indices.append(blk_idx)
                    region = textblk.get_transformed_region(img, ii, 48, maxwidth=8100)
                    regions.append(region)
    def __call__(self, textblk_lst: List[TextBlock], regions: List[np.ndarray], textblk_lst_indices: List, chunk_size = 16) -> None:

        perm = range(len(regions))
        chunck_idx = 0
@@ -433,7 +424,7 @@ class OCR48pxCTC:
            widths = [regions[i].shape[1] for i in indices]
            # max_width = 4 * (max(widths) + 7) // 4
            max_width = (4 * (max(widths) + 7) // 4) + 128
            region = np.zeros((N, 48, max_width, 3), dtype = np.uint8)
            region = np.zeros((N, self.text_height, max_width, 3), dtype = np.uint8)
            for i, idx in enumerate(indices) :
                W = regions[idx].shape[1]
                region[i, :, : W, :] = regions[idx]
+1 −1
Original line number Diff line number Diff line
@@ -125,7 +125,7 @@ if platform.system() == 'Darwin' and platform.mac_ver()[0] >= '10.15':
            def ocr_img(self, img: np.ndarray) -> str:
                return self.model(img)

            def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
            def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock], *args, **kwargs):
                im_h, im_w = img.shape[:2]
                for blk in blk_list:
                    x1, y1, x2, y2 = blk.xyxy
Loading