Commit e56e9cc3 authored by dmMaze's avatar dmMaze
Browse files

fix Use Existing Mask for blktrans

parent c81cd1b0
Loading
Loading
Loading
Loading
+27 −43
Original line number Diff line number Diff line
@@ -7,9 +7,10 @@ import numpy as np
import cv2

from utils.imgproc_utils import enlarge_window
from utils.textblock_mask import canny_flood, connected_canny_flood, existing_mask
from utils.textblock_mask import canny_flood, connected_canny_flood
from utils.logger import logger

from utils.config import pcfg
from .funcmaps import get_maskseg_method
from .module_manager import ModuleManager
from .image_edit import ImageEditMode, PenShape, PixmapItem, StrokeImgItem
from .configpanel import InpaintConfigPanel
@@ -170,7 +171,6 @@ class PenConfigPanel(Widget):
            self.thicknessChanged.emit(self.thicknessSlider.value())

    def on_alpha_changed(self):
        if self.alphaSlider.hasFocus():
        color = self.colorPicker.rgba()
        color = (color[0], color[1], color[2], self.alphaSlider.value())
        self.colorPicker.setPickerColor(color)
@@ -206,6 +206,7 @@ class RectPanel(Widget):
            self.tr('method 2'),
            self.tr('Use Existing Mask')
        ])
        self.methodComboBox.activated.connect(self.on_inpaint_seg_method_changed)
        self.autoChecker = QCheckBox(self.tr("Auto"))
        self.autoChecker.setToolTip(self.tr("run inpainting automatically."))
        self.autoChecker.stateChanged.connect(self.on_auto_changed)
@@ -244,19 +245,16 @@ class RectPanel(Widget):
        self.inpaint_layout.removeWidget(self.inpainter_panel.module_combobox)
        return super().hideEvent(e)
        
    def get_maskseg_method(self):
        if self.methodComboBox.currentIndex() == 0:
            return canny_flood
        elif self.methodComboBox.currentIndex() == 1:
            return connected_canny_flood
        elif self.methodComboBox.currentIndex() == 2:
            return existing_mask
    def on_inpaint_seg_method_changed(self):
        pcfg.drawpanel.rectool_method = self.methodComboBox.currentIndex()

    def on_auto_changed(self):
        if self.autoChecker.isChecked():
            self.inpaint_btn.hide()
            self.delete_btn.hide()
            pcfg.drawpanel.rectool_auto = True
        else:
            pcfg.drawpanel.rectool_auto = False
            self.inpaint_btn.show()
            self.delete_btn.show()

@@ -340,8 +338,8 @@ class DrawingPanel(Widget):
        self.canvas.erasing_pen = self.erasing_pen = QPen(Qt.GlobalColor.black, 1, Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap, Qt.PenJoinStyle.RoundJoin)
        self.inpaint_pen = QPen(INPAINT_BRUSH_COLOR, 1, Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap, Qt.PenJoinStyle.RoundJoin)
        
        self.setPenToolWidth(10)
        self.setPenToolColor([0, 0, 0, 127])
        # self.setPenToolWidth(10)
        # self.setPenToolColor([0, 0, 0, 127])

        self.toolConfigStackwidget = QStackedWidget()
        self.toolConfigStackwidget.setSizePolicy(QSizePolicy.Policy.Maximum, QSizePolicy.Policy.Minimum)
@@ -392,16 +390,19 @@ class DrawingPanel(Widget):

    def setInpaintToolWidth(self, width):
        self.inpaint_pen.setWidthF(width)
        pcfg.drawpanel.inpainter_width = width
        if self.isVisible():
            self.setInpaintCursor()

    def setInpaintShape(self, shape: int):
        self.setInpaintCursor()
        pcfg.drawpanel.inpainter_shape = shape
        self.canvas.painting_shape = shape

    def setPenToolWidth(self, width):
        self.pentool_pen.setWidthF(width)
        self.erasing_pen.setWidthF(width)
        pcfg.drawpanel.pentool_width = self.pentool_pen.widthF()
        if self.isVisible():
            self.setPenCursor()

@@ -409,6 +410,7 @@ class DrawingPanel(Widget):
        if not isinstance(color, QColor):
            color = QColor(*color)
        self.pentool_pen.setColor(color)
        pcfg.drawpanel.pentool_color = [color.red(), color.green(), color.blue(), color.alpha()]
        if self.isVisible():
            self.setPenCursor()
        self.penConfigPanel.colorPicker.setPickerColor(color)
@@ -417,11 +419,13 @@ class DrawingPanel(Widget):
    def setPenShape(self, shape: int):
        self.setPenCursor()
        self.canvas.painting_shape = shape
        pcfg.drawpanel.pentool_shape = shape

    def on_use_handtool(self) -> None:
        if self.currentTool is not None and self.currentTool != self.handTool:
            self.currentTool.setChecked(False)
        self.currentTool = self.handTool
        pcfg.drawpanel.current_tool = ImageEditMode.HandTool
        self.canvas.gv.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
        self.canvas.image_edit_mode = ImageEditMode.HandTool

@@ -429,6 +433,7 @@ class DrawingPanel(Widget):
        if self.currentTool is not None and self.currentTool != self.inpaintTool:
            self.currentTool.setChecked(False)
        self.currentTool = self.inpaintTool
        pcfg.drawpanel.current_tool = ImageEditMode.InpaintTool
        self.canvas.image_edit_mode = ImageEditMode.InpaintTool
        self.canvas.painting_pen = self.inpaint_pen
        self.canvas.erasing_pen = self.inpaint_pen
@@ -442,6 +447,7 @@ class DrawingPanel(Widget):
        if self.currentTool is not None and self.currentTool != self.penTool:
            self.currentTool.setChecked(False)
        self.currentTool = self.penTool
        pcfg.drawpanel.current_tool = ImageEditMode.PenTool
        self.canvas.painting_pen = self.pentool_pen
        self.canvas.painting_shape = self.penConfigPanel.shape
        self.canvas.erasing_pen = self.erasing_pen
@@ -455,34 +461,12 @@ class DrawingPanel(Widget):
        if self.currentTool is not None and self.currentTool != self.rectTool:
            self.currentTool.setChecked(False)
        self.currentTool = self.rectTool
        pcfg.drawpanel.current_tool = ImageEditMode.RectTool
        self.toolConfigStackwidget.setCurrentWidget(self.rectPanel)
        self.canvas.gv.setDragMode(QGraphicsView.DragMode.NoDrag)
        self.canvas.image_edit_mode = ImageEditMode.RectTool
        self.setCrossCursor()

    def get_config(self) -> DrawPanelConfig:
        config = DrawPanelConfig()
        pc = self.pentool_pen.color()
        config.pentool_color = [pc.red(), pc.green(), pc.blue(), pc.alpha()]
        config.pentool_width = self.pentool_pen.widthF()
        config.pentool_shape = self.penConfigPanel.shape

        config.inpainter_width = self.inpaint_pen.widthF()
        config.inpainter_shape = self.penConfigPanel.shape

        if self.currentTool == self.handTool:
            config.current_tool = ImageEditMode.HandTool
        elif self.currentTool == self.inpaintTool:
            config.current_tool = ImageEditMode.InpaintTool
        elif self.currentTool == self.penTool:
            config.current_tool = ImageEditMode.PenTool
        elif self.currentTool == self.rectTool:
            config.current_tool = ImageEditMode.RectTool
        config.recttool_dilate_ksize = self.rectPanel.dilate_slider.value()
        config.rectool_auto = self.rectPanel.autoChecker.isChecked()
        config.rectool_method = self.rectPanel.methodComboBox.currentIndex()
        return config

    def set_config(self, config: DrawPanelConfig):
        self.setPenToolWidth(config.pentool_width)
        self.setPenToolColor(config.pentool_color)
@@ -780,9 +764,8 @@ class DrawingPanel(Widget):
                return
            if mode == 0:
                im = np.copy(img[y1: y2, x1: x2])
                maskseg_method = self.rectPanel.get_maskseg_method()
                mask = self.canvas.imgtrans_proj.mask_array[y1: y2, x1: x2]
                inpaint_mask_array, ballon_mask, bub_dict = maskseg_method(im, mask=mask)
                maskseg_method = get_maskseg_method()
                inpaint_mask_array, ballon_mask, bub_dict = maskseg_method(im, mask=self.canvas.imgtrans_proj.mask_array[y1: y2, x1: x2])
                mask = self.rectPanel.post_process_mask(inpaint_mask_array)

                bground_bgr = bub_dict['bground_bgr']
@@ -830,6 +813,7 @@ class DrawingPanel(Widget):
        self.clearInpaintItems()

    def on_rectool_ksize_changed(self):
        pcfg.drawpanel.recttool_dilate_ksize = self.rectPanel.dilate_slider.value()
        if self.currentTool != self.rectTool or self.inpaint_mask_array is None or self.inpaint_mask_item is None:
            return
        mask = self.rectPanel.post_process_mask(self.inpaint_mask_array)
+7 −2
Original line number Diff line number Diff line
from utils.io_utils import build_funcmap
from utils.fontformat import FontFormat

from utils.config import pcfg
from utils.textblock_mask import canny_flood, connected_canny_flood, existing_mask

handle_ffmt_change = build_funcmap('ui.fontformat_commands', 
                                     list(FontFormat.params().keys()), 
                                     'ffmt_change_', verbose=False)


def get_maskseg_method():
    return [canny_flood, connected_canny_flood, existing_mask][pcfg.drawpanel.rectool_method]
 No newline at end of file
+2 −3
Original line number Diff line number Diff line
@@ -272,7 +272,6 @@ class MainWindow(mainwindow_cls):
        module_manager.progress_msgbox.showed.connect(self.on_imgtrans_progressbox_showed)
        module_manager.imgtrans_thread.mask_postprocess = self.drawingPanel.rectPanel.post_process_mask
        module_manager.blktrans_pipeline_finished.connect(self.on_blktrans_finished)
        module_manager.imgtrans_thread.get_maskseg_method = self.drawingPanel.rectPanel.get_maskseg_method
        module_manager.imgtrans_thread.post_process_mask = self.drawingPanel.rectPanel.post_process_mask

        self.leftBar.run_imgtrans.connect(self.on_run_imgtrans)
@@ -449,7 +448,6 @@ class MainWindow(mainwindow_cls):
            self.restart_signal.emit()

    def save_config(self):
        pcfg.drawpanel = self.drawingPanel.get_config()
        save_config()

    def onHideCanvas(self):
@@ -880,6 +878,7 @@ class MainWindow(mainwindow_cls):
        tgt_img = self.imgtrans_proj.img_array
        if tgt_img is None:
            return False
        tgt_mask = self.imgtrans_proj.mask_array
        
        if len(blkitem_list) < 1:
            return False
@@ -897,7 +896,7 @@ class MainWindow(mainwindow_cls):
            blk.set_lines_by_xywh(blk._bounding_rect, angle=-blk.angle, x_range=[0, im_w-1], y_range=[0, im_h-1], adjust_bbox=True)
            blk_list.append(blk)

        self.module_manager.runBlktransPipeline(blk_list, tgt_img, mode, blk_ids)
        self.module_manager.runBlktransPipeline(blk_list, tgt_img, mode, blk_ids, tgt_mask = tgt_mask)
        return True


+9 −8
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from typing import Union, List, Dict, Callable
import numpy as np
from qtpy.QtCore import QThread, Signal, QObject, QLocale, QTimer

from .funcmaps import get_maskseg_method
from utils.logger import logger as LOGGER
from utils.registry import Registry
from utils.imgproc_utils import enlarge_window, get_block_mask
@@ -305,14 +306,14 @@ class ImgtransThread(QThread):
        self.job = self._imgtrans_pipeline
        self.start()

    def runBlktransPipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int, blk_ids: List[int]):
        self.job = lambda : self._blktrans_pipeline(blk_list, tgt_img, mode, blk_ids)
    def runBlktransPipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int, blk_ids: List[int], tgt_mask):
        self.job = lambda : self._blktrans_pipeline(blk_list, tgt_img, mode, blk_ids, tgt_mask)
        self.start()

    def _blktrans_pipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int, blk_ids: List[int]):
    def _blktrans_pipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int, blk_ids: List[int], tgt_mask):
        if mode >= 0 and mode < 3:
            try:
                self.ocr_thread.module.run_ocr(tgt_img, blk_list, split_textblk=True, seg_func=self.get_maskseg_method())
                self.ocr_thread.module.run_ocr(tgt_img, blk_list, split_textblk=True)
            except Exception as e:
                create_error_dialog(e, self.tr('OCR Failed.'), 'OCRFailed')
            self.finish_blktrans.emit(mode, blk_ids)
@@ -330,8 +331,8 @@ class ImgtransThread(QThread):
                blk.region_inpaint_dict = None
                if y2 - y1 > 2 and x2 - x1 > 2:
                    im = np.copy(tgt_img[y1: y2, x1: x2])
                    maskseg_method = self.get_maskseg_method()
                    inpaint_mask_array, ballon_mask, bub_dict = maskseg_method(im)
                    maskseg_method = get_maskseg_method()
                    inpaint_mask_array, ballon_mask, bub_dict = maskseg_method(im, mask=tgt_mask[y1: y2, x1: x2])
                    mask = self.post_process_mask(inpaint_mask_array)
                    if mask.sum() > 0:
                        inpainted = self.inpaint_thread.inpainter.inpaint(im, mask)
@@ -746,7 +747,7 @@ class ModuleManager(QObject):
        self.progress_msgbox.show()
        self.imgtrans_thread.runImgtransPipeline(self.imgtrans_proj)

    def runBlktransPipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int, blk_ids: List[int]):
    def runBlktransPipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int, blk_ids: List[int], tgt_mask):
        self.terminateRunningThread()
        self.progress_msgbox.hide_all_bars()
        if mode >= 0 and mode < 3:
@@ -757,7 +758,7 @@ class ModuleManager(QObject):
            self.progress_msgbox.translate_bar.show()
        self.progress_msgbox.zero_progress()
        self.progress_msgbox.show()
        self.imgtrans_thread.runBlktransPipeline(blk_list, tgt_img, mode, blk_ids)
        self.imgtrans_thread.runBlktransPipeline(blk_list, tgt_img, mode, blk_ids, tgt_mask)

    def on_finish_blktrans_stage(self, stage: str, progress: int):
        if stage == 'ocr':
+2 −1
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from .structures import Tuple, Union, List, Dict, field, nested_dataclass
from .split_text_region import split_textblock as split_text_region
from .fontformat import FontFormat, LineSpacingType, TextAlignment, fix_fontweight_qt
from . import shared
from .textblock_mask import canny_flood


LANG_LIST = ['eng', 'ja', 'unknown']
@@ -806,7 +807,7 @@ def collect_textblock_regions(img: np.ndarray, textblk_lst: List[TextBlock], tex
    for blk_idx, textblk in enumerate(textblk_lst):
        for ii in range(len(textblk)):
            if split_textblk and len(textblk) == 1:
                assert seg_func is not None
                seg_func = canny_flood
                region = textblk.get_transformed_region(img, ii, None, maxwidth=None)
                mask  = seg_func(region)[0]
                split_lines = split_text_region(mask)[0]