Commit 38e2232e authored by dmMaze's avatar dmMaze
Browse files

Support ocr translate inpaint selected text blocks

parent e160e455
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -29,8 +29,21 @@ class InpainterBase(ModuleParamParser):
    def setup_inpainter(self):
        """Hook for inpainter setup; this base implementation always raises.

        Raises:
            NotImplementedError: unconditionally, so an implementation has to
                be supplied by overriding this method.
        """
        raise NotImplementedError()

    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None, check_need_inpaint: bool = False) -> np.ndarray:
        if not self.inpaint_by_block or textblock_list is None:
            if check_need_inpaint:
                ballon_msk, non_text_msk = extract_ballon_mask(img, mask)
                if ballon_msk is not None:
                    non_text_region = np.where(non_text_msk > 0)
                    non_text_px = img[non_text_region]
                    average_bg_color = np.mean(non_text_px, axis=0)
                    std_bgr = np.std(non_text_px - average_bg_color, axis=0)
                    std_max = np.max(std_bgr)
                    inpaint_thresh = 7 if np.std(std_bgr) > 1 else 10
                    if std_max < inpaint_thresh:
                        img = img.copy()
                        img[np.where(ballon_msk > 0)] = average_bg_color
                        return img
            return self._inpaint(img, mask)
        else:
            im_h, im_w = img.shape[:2]
@@ -41,7 +54,7 @@ class InpainterBase(ModuleParamParser):
                im = inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                msk = mask[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                need_inpaint = True
                if self.check_need_inpaint:
                if self.check_need_inpaint or check_need_inpaint:
                    ballon_msk, non_text_msk = extract_ballon_mask(im, msk)
                    if ballon_msk is not None:
                        non_text_region = np.where(non_text_msk > 0)
+8 −1
Original line number Diff line number Diff line
from typing import Tuple, List, Dict, Union
import numpy as np
import cv2
import logging

from ..textdetector.textblock import TextBlock

@@ -139,9 +140,15 @@ class MangaOCR(OCRBase):
        return self.model(img)

    def ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Run the OCR model on every text block and store the result in ``blk.text``.

        Each block's ``xyxy`` box is used to crop ``img`` before the crop is
        passed to ``self.model``. A box is accepted when it is non-empty and
        lies inside the image. Boxes that touch the image border
        (``x1 == 0``, ``y1 == 0``, ``x2 == im_w`` or ``y2 == im_h``) are valid
        crops and are recognized too. For any other box a warning is logged
        and ``blk.text`` is set to the empty string.

        Args:
            img: image to crop from; only its first two dimensions are read.
            blk_list: blocks exposing ``xyxy`` (x1, y1, x2, y2); each block's
                ``text`` attribute is overwritten in place.
        """
        im_h, im_w = img.shape[:2]
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
            # Slice stops are exclusive, so x2 == im_w / y2 == im_h still
            # address pixels inside the image, and index 0 is a valid start.
            if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                blk.text = self.model(img[y1:y2, x1:x2])
            else:
                logging.warning('invalid textbbox to target img')
                blk.text = ''

    def updateParam(self, param_key: str, param_content):
        super().updateParam(param_key, param_content)
+45 −2
Original line number Diff line number Diff line
@@ -97,7 +97,10 @@ class TextBlock(object):
        self.shadow_color = shadow_color
        self.shadow_offset = shadow_offset

    def adjust_bbox(self, with_bbox=False):
        self.region_mask: np.ndarray = None
        self.region_inpaint_dict: dict = None

    def adjust_bbox(self, with_bbox=False, x_range=None, y_range=None):
        lines = self.lines_array().astype(np.int32)
        if with_bbox:
            self.xyxy[0] = min(lines[..., 0].min(), self.xyxy[0])
@@ -110,6 +113,13 @@ class TextBlock(object):
            self.xyxy[2] = lines[..., 0].max()
            self.xyxy[3] = lines[..., 1].max()

        if x_range is not None:
            self.xyxy[0] = np.clip(self.xyxy[0], x_range[0], x_range[1])
            self.xyxy[2] = np.clip(self.xyxy[2], x_range[0], x_range[1])
        if y_range is not None:
            self.xyxy[1] = np.clip(self.xyxy[1], y_range[0], y_range[1])
            self.xyxy[3] = np.clip(self.xyxy[3], y_range[0], y_range[1])

    def sort_lines(self):
        if self.distance is not None:
            idx = np.argsort(self.distance)
@@ -120,6 +130,26 @@ class TextBlock(object):
    def lines_array(self, dtype=np.float64):
        """Return ``self.lines`` converted to a new numpy array of ``dtype``."""
        line_points = self.lines
        return np.array(line_points, dtype=dtype)

    def set_lines_by_xywh(self, xywh: np.ndarray, angle=0, x_range=None, y_range=None, adjust_bbox=False):
        """Rebuild ``self.lines`` from an ``xywh`` box.

        Args:
            xywh: box handed to ``xywh2xyxypoly``; a ``List`` is converted to
                an ndarray first. Indices 0..3 are read as x, y, w, h.
            angle: when non-zero, the polygon is passed to ``rotate_polygons``
                with the box center ``(x + w / 2, y + h / 2)`` as pivot.
                NOTE(review): the angle unit is decided by ``rotate_polygons``
                and is not visible here - confirm before relying on it.
            x_range: optional ``(low, high)`` pair; x coordinates are clipped to it.
            y_range: optional ``(low, high)`` pair; y coordinates are clipped to it.
            adjust_bbox: when truthy, ``self.adjust_bbox()`` is called after
                the lines have been replaced.
        """
        if isinstance(xywh, List):
            xywh = np.array(xywh)

        polys = xywh2xyxypoly(np.array([xywh]))
        if angle != 0:
            # Rotate around the center of the box.
            pivot = [xywh[0] + xywh[2] / 2., xywh[1] + xywh[3] / 2.]
            polys = rotate_polygons(pivot, polys, angle)

        polys = polys.reshape(-1, 4, 2)
        # Axis 0 of the last dimension holds x, axis 1 holds y.
        for axis, bounds in enumerate((x_range, y_range)):
            if bounds is not None:
                polys[..., axis] = np.clip(polys[..., axis], bounds[0], bounds[1])
        self.lines = polys.tolist()

        if adjust_bbox:
            self.adjust_bbox()

    def aspect_ratio(self) -> float:
        min_rect = self.min_rect()
        middle_pnts = (min_rect[:, [1, 2, 3, 0]] + min_rect) / 2
@@ -188,25 +218,38 @@ class TextBlock(object):
    def get_transformed_region(self, img: np.ndarray, idx: int, textheight: int, maxwidth: int = None) -> np.ndarray :
        direction = 'v' if self.vertical else 'h'
        src_pts = np.array(self.lines[idx], dtype=np.float64)
        im_h, im_w = img.shape[:2]

        middle_pnt = (src_pts[[1, 2, 3, 0]] + src_pts) / 2
        vec_v = middle_pnt[2] - middle_pnt[0]   # vertical vectors of textlines
        vec_h = middle_pnt[1] - middle_pnt[3]   # horizontal vectors of textlines
        ratio = np.linalg.norm(vec_v) / np.linalg.norm(vec_h)
        norm_v = np.linalg.norm(vec_v)
        norm_h = np.linalg.norm(vec_h)
        if norm_v <= 0 or norm_h <= 0:
            print('invalid textpolygon to target img')
            return np.zeros((textheight, textheight, 3), dtype=np.uint8)
        ratio = norm_v / norm_h

        if direction == 'h' :
            h = int(textheight)
            w = int(round(textheight / ratio))
            dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
            M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            if M is None:
                print('invalid textpolygon to target img')
                return np.zeros((textheight, textheight, 3), dtype=np.uint8)
            region = cv2.warpPerspective(img, M, (w, h))
        elif direction == 'v' :
            w = int(textheight)
            h = int(round(textheight * ratio))
            dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
            M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            if M is None:
                print('invalid textpolygon to target img')
                return np.zeros((textheight, textheight, 3), dtype=np.uint8)
            region = cv2.warpPerspective(img, M, (w, h))
            region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)

        if maxwidth is not None:
            h, w = region.shape[: 2]
            if w > maxwidth:
+17 −0
Original line number Diff line number Diff line
@@ -93,6 +93,8 @@ class Canvas(QGraphicsScene):
    format_textblks = Signal()
    layout_textblks = Signal()

    run_blktrans = Signal(int)

    begin_scale_tool = Signal(QPointF)
    scale_tool = Signal(QPointF)
    end_scale_tool = Signal()
@@ -542,7 +544,14 @@ class Canvas(QGraphicsScene):
            menu.addSeparator()
            format_act = menu.addAction(self.tr("Apply font formatting"))
            layout_act = menu.addAction(self.tr("Auto layout"))
            menu.addSeparator()
            translate_act = menu.addAction(self.tr("translate"))
            ocr_act = menu.addAction(self.tr("OCR"))
            ocr_translate_act = menu.addAction(self.tr("OCR and translate"))
            ocr_translate_inpaint_act = menu.addAction(self.tr("OCR, translate and inpaint"))

            rst = menu.exec_(event.screenPos())
            
            if rst == delete_act:
                self.delete_textblks.emit()
            elif rst == copy_act:
@@ -553,6 +562,14 @@ class Canvas(QGraphicsScene):
                self.format_textblks.emit()
            elif rst == layout_act:
                self.layout_textblks.emit()
            elif rst == translate_act:
                self.run_blktrans.emit(-1)
            elif rst == ocr_act:
                self.run_blktrans.emit(0)
            elif rst == ocr_translate_act:
                self.run_blktrans.emit(1)
            elif rst == ocr_translate_inpaint_act:
                self.run_blktrans.emit(2)
    
    def on_hide_canvas(self):
        self.clear_states()
+73 −6
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ from qtpy.QtWidgets import QMessageBox

from utils.logger import logger as LOGGER
from utils.registry import Registry
from utils.imgproc_utils import enlarge_window
from dl.translators import MissingTranslatorParams
from dl import INPAINTERS, TRANSLATORS, TEXTDETECTORS, OCR, \
    VALID_TRANSLATORS, VALID_TEXTDETECTORS, VALID_INPAINTERS, VALID_OCR, \
@@ -19,6 +20,8 @@ from .stylewidgets import ImgtransProgressMessageBox
from .configpanel import ConfigPanel
from .misc import DLModuleConfig, ProgramConfig
from .imgtrans_proj import ProjImgTrans
from dl.textdetector import TextBlock


class ModuleThread(QThread):

@@ -246,6 +249,8 @@ class ImgtransThread(QThread):
    update_inpaint_progress = Signal(int)
    exception_occurred = Signal(str, str)

    finish_blktrans_stage = Signal(str, int)

    def __init__(self, 
                 dl_config: DLModuleConfig, 
                 textdetect_thread: TextDetectThread,
@@ -284,6 +289,36 @@ class ImgtransThread(QThread):
        self.job = self._imgtrans_pipeline
        self.start()

    def runBlktransPipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int):
        """Store the block-level pipeline as ``self.job`` and call ``self.start()``.

        The three arguments are captured and forwarded unchanged to
        ``_blktrans_pipeline`` when ``self.job`` is invoked.
        """
        def blktrans_job():
            return self._blktrans_pipeline(blk_list, tgt_img, mode)

        self.job = blktrans_job
        self.start()

    def _blktrans_pipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int):
        """Run OCR / translation / inpainting on the given text blocks.

        ``mode`` selects which stages run:
            * ``mode >= 0``: OCR via ``self.ocr_thread.module``
            * ``mode != 0``: translation via ``self.translate_thread.module``
            * ``mode > 1``:  per-block inpainting

        That is, -1 translates only, 0 runs OCR only, 1 runs OCR and
        translation, and 2 runs all three.

        Progress is reported through ``finish_blktrans_stage(stage, percent)``
        with stage ``'ocr'``, ``'translate'`` or ``'inpaint'``. When everything
        is done, ``finish_blktrans_stage(str(mode), 0)`` is emitted.

        For inpainting, each block's ``region_inpaint_dict`` is reset to
        ``None`` and only filled in when the enlarged window is larger than
        2 px in both directions and the computed mask is non-empty.
        """
        if mode >= 0:
            self.ocr_thread.module.ocr_blk_list(tgt_img, blk_list)
            self.finish_blktrans_stage.emit('ocr', 100)
        if mode != 0:
            self.translate_thread.module.translate_textblk_lst(blk_list)
            self.finish_blktrans_stage.emit('translate', 100)
        if mode > 1:
            im_h, im_w = tgt_img.shape[:2]
            progress_prod = 100. / len(blk_list) if len(blk_list) > 0 else 0
            for ii, blk in enumerate(blk_list):
                xyxy = enlarge_window(blk.xyxy, im_w, im_h)
                xyxy = np.array(xyxy)
                x1, y1, x2, y2 = xyxy.astype(np.int64)
                blk.region_inpaint_dict = None
                if y2 - y1 > 2 and x2 - x1 > 2:
                    # Work on a copy so the target image itself is untouched.
                    im = np.copy(tgt_img[y1: y2, x1: x2])
                    maskseg_method = self.get_maskseg_method()
                    inpaint_mask_array, ballon_mask, bub_dict = maskseg_method(im)
                    mask = self.post_process_mask(inpaint_mask_array)
                    if mask.sum() > 0:
                        inpainted = self.inpaint_thread.inpainter.inpaint(im, mask)
                        blk.region_inpaint_dict = {'img': im, 'mask': mask, 'inpaint_rect': [x1, y1, x2, y2], 'inpainted': inpainted}
                # Report progress for every block, including the ones skipped
                # above, so the bar always reaches 100 after the last block.
                self.finish_blktrans_stage.emit('inpaint', int((ii+1) * progress_prod))
        self.finish_blktrans_stage.emit(str(mode), 0)

    def _imgtrans_pipeline(self):
        self.detect_counter = 0
        self.ocr_counter = 0
@@ -409,6 +444,7 @@ class DLManager(QObject):
    canvas_inpaint_finished = Signal(dict)

    imgtrans_pipeline_finished = Signal()
    blktrans_pipeline_finished = Signal(int)
    page_trans_finished = Signal(int)

    run_canvas_inpaint = False
@@ -450,6 +486,7 @@ class DLManager(QObject):
        self.imgtrans_thread.update_translate_progress.connect(self.on_update_translate_progress)
        self.imgtrans_thread.update_inpaint_progress.connect(self.on_update_inpaint_progress)
        self.imgtrans_thread.exception_occurred.connect(self.handleRunTimeException)
        self.imgtrans_thread.finish_blktrans_stage.connect(self.on_finish_blktrans_stage)

        self.translator_panel = translator_panel = config_panel.trans_config_panel        
        translator_setup_params = merge_config_module_params(dl_config.translator_setup_params, VALID_TRANSLATORS, TRANSLATORS.get)
@@ -518,12 +555,7 @@ class DLManager(QObject):
            return
        self.inpaint_thread.inpaint(img, mask, img_key, inpaint_rect)

    def runImgtransPipeline(self):
        if self.imgtrans_proj.is_empty:
            LOGGER.info('proj file is empty, nothing to do')
            self.progress_msgbox.hide()
            return
        self.last_finished_index = -1
    def terminateRunningThread(self):
        if self.textdetect_thread.isRunning():
            self.textdetect_thread.terminate()
        if self.ocr_thread.isRunning():
@@ -533,6 +565,15 @@ class DLManager(QObject):
        if self.translate_thread.isRunning():
            self.translate_thread.terminate()

    def runImgtransPipeline(self):
        if self.imgtrans_proj.is_empty:
            LOGGER.info('proj file is empty, nothing to do')
            self.progress_msgbox.hide()
            return
        self.last_finished_index = -1
        self.terminateRunningThread()
        
        self.progress_msgbox.show_all_bars()
        if not self.dl_config.enable_ocr:
            self.progress_msgbox.ocr_bar.hide()
            self.progress_msgbox.translate_bar.hide()
@@ -546,6 +587,32 @@ class DLManager(QObject):
        self.progress_msgbox.show()
        self.imgtrans_thread.runImgtransPipeline(self.imgtrans_proj)

    def runBlktransPipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int):
        """Prepare the progress box and start the block-level pipeline.

        After ``terminateRunningThread()`` all progress bars are hidden and
        only the bars for the stages selected by ``mode`` are shown again:
        OCR for ``mode >= 0``, inpaint for ``mode == 2``, translate for
        ``mode != 0``. Progress is then zeroed, the box is shown, and the
        arguments are handed to ``self.imgtrans_thread.runBlktransPipeline``.
        """
        self.terminateRunningThread()

        msgbox = self.progress_msgbox
        msgbox.hide_all_bars()
        # (bar attribute, whether the stage runs for this mode), in show order.
        bar_conditions = (
            ('ocr_bar', mode >= 0),
            ('inpaint_bar', mode == 2),
            ('translate_bar', mode != 0),
        )
        for bar_name, stage_enabled in bar_conditions:
            if stage_enabled:
                getattr(msgbox, bar_name).show()
        msgbox.zero_progress()
        msgbox.show()

        self.imgtrans_thread.runBlktransPipeline(blk_list, tgt_img, mode)

    def on_finish_blktrans_stage(self, stage: str, progress: int):
        """Handle a ``finish_blktrans_stage`` notification from the worker thread.

        Args:
            stage: ``'ocr'``, ``'translate'`` or ``'inpaint'`` for a progress
                update of that stage, or the pipeline mode as a string
                (``'-1'``, ``'0'``, ``'1'``, ``'2'``) once the whole pipeline
                has finished.
            progress: percentage forwarded to the matching progress bar;
                not used for the completion notification.

        Raises:
            NotImplementedError: if ``stage`` is none of the values above.
        """
        if stage == 'ocr':
            self.progress_msgbox.updateOCRProgress(progress)
        elif stage == 'translate':
            self.progress_msgbox.updateTranslateProgress(progress)
        elif stage == 'inpaint':
            self.progress_msgbox.updateInpaintProgress(progress)
        # '-1' is the translate-only mode; it must be accepted like the other
        # modes so the progress box is hidden and listeners are notified.
        elif stage in {'-1', '0', '1', '2'}:
            self.blktrans_pipeline_finished.emit(int(stage))
            self.progress_msgbox.hide()
        else:
            raise NotImplementedError(f'Unknown stage: {stage}')

    def on_update_detect_progress(self, progress: int):
        ri = self.imgtrans_thread.recent_finished_index(progress)
        progress = int(progress / self.imgtrans_thread.num_pages * 100)
Loading