Commit 808f62fc authored by dmMaze's avatar dmMaze
Browse files

allow textblock label for ysgyolo, add mask dilate params to text detectors

parent 6eec0cdc
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -92,7 +92,8 @@ class BaseModule:
                try:
                    param_value = type(p)(param_value)
                except ValueError:
                    self.logger.warning(f'Invalid param value {param_value} for defined dtype: {type(p)}')
                    self.logger.warning(f'Invalid param value {param_value} for defined dtype: {type(p)}, revert to original value {p}')
                    param_value = p
            self.params[param_key] = param_value

    def updateParam(self, param_key: str, param_content):
+7 −0
Original line number Diff line number Diff line
import numpy as np
import cv2
from typing import Tuple, List

from .base import register_textdetectors, TextDetectorBase, TextBlock, DEFAULT_DEVICE, DEVICE_SELECTOR, ProjImgTrans
@@ -31,6 +32,7 @@ class ComicTextDetector(TextDetectorBase):
        'font size multiplier': 1.,
        'font size max': -1,
        'font size min': -1,
        'mask dilate size': 2
    }
    _load_model_keys = {'model'}
    download_file_list = [{
@@ -75,6 +77,11 @@ class ComicTextDetector(TextDetectorBase):
            blk.font_size = sz
            blk._detected_font_size = sz

        ksize = self.get_param_value('mask dilate size')
        if ksize > 0:
            element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * ksize + 1, 2 * ksize + 1),(ksize, ksize))
            mask = cv2.dilate(mask, element)

        return mask, blk_list

    def updateParam(self, param_key: str, param_content):
+61 −5
Original line number Diff line number Diff line
@@ -8,6 +8,8 @@ import cv2

from .base import register_textdetectors, TextDetectorBase, TextBlock, DEVICE_SELECTOR
from utils.textblock import mit_merge_textlines
from utils.textblock_mask import canny_flood
from utils.split_text_region import manga_split, split_textblock
from utils.imgproc_utils import xywh2xyxypoly
from utils.proj_imgtrans import ProjImgTrans

@@ -63,10 +65,13 @@ class YSGYoloDetector(TextDetectorBase):
                'vertical_textline': True, 
                'horizontal_textline': True, 
                'angled_vertical_textline': True, 
                'angled_horizontal_textline': True
                'angled_horizontal_textline': True,
                'textblock': True
            }, 
            'type': 'check_group'
        }
        },
        'source text is vertical': True,
        'mask dilate size': 2
    }

    _load_model_keys = {'yolo'}
@@ -81,7 +86,7 @@ class YSGYoloDetector(TextDetectorBase):
            self.yolo = YOLO(self.get_param_value('model path')).to(device=self.get_param_value('device'))

    def get_valid_labels(self):
        """Return the names of the enabled detection labels.

        Reads the ``label`` entry of ``self.params`` and keeps every key whose
        value is truthy. The ``'textblock'`` key is always left out, even when
        it is enabled; other code in this class reads that flag directly from
        ``self.params``.

        Returns:
            list: Enabled label names, in the iteration order of the
            underlying mapping.
        """
        label_states = self.params['label']['value']
        # A single pass is enough: the unfiltered list built previously was
        # overwritten straight away by this filtered one.
        return [
            name
            for name, enabled in label_states.items()
            if enabled and name != 'textblock'
        ]

    @property
@@ -97,17 +102,23 @@ class YSGYoloDetector(TextDetectorBase):
        )[0]
        valid_ids = []
        valid_labels = set(self.get_valid_labels())
        textblock_idx = -1
        for idx, name in result.names.items():
            if CLS_MAP[name] in valid_labels:
                valid_ids.append(idx)
            if name == 'qipao':
                textblock_idx = idx
        need_textblock = self.params['label']['value']['textblock'] == True

        mask = np.zeros_like(img[..., 0])
        if len(valid_ids) == 0:
        if len(valid_ids) == 0 and not need_textblock:
            return [], mask

        im_h, im_w = img.shape[:2]
        pts_list = []

        blk_list = []

        dets = result.boxes
        if dets is not None and len(dets.cls) > 0:
            device = dets.cls.device
@@ -126,6 +137,46 @@ class YSGYoloDetector(TextDetectorBase):
                xyxy_list[:, [2, 3]] -= xyxy_list[:, [0, 1]]
                pts_list += xywh2xyxypoly(xyxy_list).reshape(-1, 4, 2).tolist()
            
            if need_textblock:
                valid_mask = dets.cls == textblock_idx
                is_vertical = self.get_param_value('source text is vertical')
                if torch.any(valid_mask):
                    xyxy_list = dets.xyxy[valid_mask]
                    xyxy_list = xyxy_list.to(device='cpu', dtype=torch.float32).round().to(torch.int32)
                    xyxy_list[:, [0, 2]] = torch.clip(xyxy_list[:, [0, 2]], 0, im_w - 1)
                    xyxy_list[:, [1, 3]] = torch.clip(xyxy_list[:, [1, 3]], 0, im_h - 1)
                    xyxy_list = xyxy_list.numpy()
                    for xyxy in xyxy_list:
                        x1, y1, x2, y2 = xyxy
                        crop = img[y1: y2, x1: x2]
                        bmask  = canny_flood(crop)[0]
                        if is_vertical:
                            span_list = manga_split(bmask)
                            lines = [[line.left + x1, line.top + y1, line.width, line.height] for line in span_list]
                            lines = np.array(lines)[::-1]
                            font_sz = np.mean(lines[:, 2])
                        else:
                            span_list = split_textblock(bmask)[0]
                            lines = [[line.left + x1, line.top + y1, line.width, line.height] for line in span_list]
                            lines = np.array(lines)
                            font_sz = np.mean(lines[:, 3])
                        for line in lines:
                            x1, y1, x2, y2 = line
                            x2 += x1
                            y2 += y1
                            cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)
                        lines = xywh2xyxypoly(lines).reshape(-1, 4, 2).tolist()
                        blk = TextBlock(xyxy=xyxy, lines=np.array(lines), src_is_vertical=is_vertical)
                        blk.font_size = font_sz
                        blk._detected_font_size = font_sz
                        blk_list.append(blk)
                        
                        # cv2.imwrite('mask.jpg', mask)
                        # for ii in range(len(blk.lines)):
                        #     rst = blk.get_transformed_region(img, ii, 48)
                        #     cv2.imwrite('local_tst.jpg', rst)
                        #     pass

        # oriented objects
        dets = result.obb
        if dets is not None and len(dets.cls) > 0:
@@ -143,7 +194,7 @@ class YSGYoloDetector(TextDetectorBase):
                    cv2.fillPoly(mask, [pts], 255)
                pts_list += xyxy_list.tolist()

        blk_list: List[TextBlock] = mit_merge_textlines(pts_list, width=im_w, height=im_h)
        blk_list += mit_merge_textlines(pts_list, width=im_w, height=im_h)

        fnt_rsz = self.get_param_value('font size multiplier')
        fnt_max = self.get_param_value('font size max')
@@ -157,6 +208,11 @@ class YSGYoloDetector(TextDetectorBase):
            blk.font_size = sz
            blk._detected_font_size = sz

        ksize = self.get_param_value('mask dilate size')
        if ksize > 0:
            element = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * ksize + 1, 2 * ksize + 1),(ksize, ksize))
            mask = cv2.dilate(mask, element)
            
        return mask, blk_list

    def updateParam(self, param_key: str, param_content):
+0 −1
Original line number Diff line number Diff line
@@ -281,7 +281,6 @@ class MainWindow(mainwindow_cls):
        module_manager.page_trans_finished.connect(self.on_pagtrans_finished)
        module_manager.setupThread(self.configPanel, self.imgtrans_progress_msgbox, self.ocr_postprocess, self.translate_preprocess, self.translate_postprocess)
        module_manager.progress_msgbox.showed.connect(self.on_imgtrans_progressbox_showed)
        module_manager.imgtrans_thread.mask_postprocess = self.drawingPanel.rectPanel.post_process_mask
        module_manager.blktrans_pipeline_finished.connect(self.on_blktrans_finished)
        module_manager.imgtrans_thread.post_process_mask = self.drawingPanel.rectPanel.post_process_mask

+0 −3
Original line number Diff line number Diff line
@@ -284,7 +284,6 @@ class ImgtransThread(QThread):
        self.inpaint_thread = inpaint_thread
        self.job = None
        self.imgtrans_proj: ProjImgTrans = None
        self.mask_postprocess = None

    @property
    def textdetector(self) -> TextDetectorBase:
@@ -365,8 +364,6 @@ class ImgtransThread(QThread):
            if cfg_module.enable_detect:
                try:
                    mask, blk_list = self.textdetector.detect(img, self.imgtrans_proj)
                    if self.mask_postprocess is not None:
                        mask = self.mask_postprocess(mask)
                    need_save_mask = True
                except Exception as e:
                    create_error_dialog(e, self.tr('Text Detection Failed.'), 'TextDetectFailed')
Loading