Commit 983b6305 authored by dmMaze
Browse files

support disable/enable certain stages (#190)

parent f3b4bdd2
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -57,6 +57,7 @@ class TextBlock:
    shadow_strength: float = 1.
    shadow_color: Tuple = (0, 0, 0)
    shadow_offset: List = field(default_factory = lambda : [0., 0.])
    src_is_vertical: bool = False

    region_mask: np.ndarray = None
    region_inpaint_dict: Dict = None
@@ -185,9 +186,8 @@ class TextBlock:
        return blk_dict

    def get_transformed_region(self, img: np.ndarray, idx: int, textheight: int, maxwidth: int = None) -> np.ndarray :
        direction = 'v' if self.vertical else 'h'
        direction = 'v' if self.src_is_vertical else 'h'
        src_pts = np.array(self.lines[idx], dtype=np.float64)
        im_h, im_w = img.shape[:2]

        middle_pnt = (src_pts[[1, 2, 3, 0]] + src_pts) / 2
        vec_v = middle_pnt[2] - middle_pnt[0]   # vertical vectors of textlines
@@ -223,6 +223,7 @@ class TextBlock:
            h, w = region.shape[: 2]
            if w > maxwidth:
                region = cv2.resize(region, (maxwidth, h))

        return region

    def get_text(self):
@@ -397,7 +398,7 @@ def examine_textblk(blk: TextBlock, im_w: int, im_h: int, sort: bool = False) ->
    if abs(blk.angle) < 3:
        blk.angle = 0
    blk.font_size = font_size
    blk.vertical = vertical
    blk.vertical = blk.src_is_vertical = vertical
    blk.vec = primary_vec
    blk.norm = primary_norm
    if sort:
+1 −0
Original line number Diff line number Diff line
@@ -39,6 +39,7 @@ class SizeComboBox(QComboBox):
    def on_text_changed(self):
        """Record a user-driven text edit and run the change check.

        Does nothing unless this widget currently has focus.
        """
        if not self.hasFocus():
            return
        self.text_changed_by_user = True
        self.check_change()

    def on_current_index_changed(self):
        if self.hasFocus():
+19 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ class ModuleConfig(Config):
    ocr: str = "mit48px_ctc"
    inpainter: str = 'lama_mpe'
    translator: str = "google"
    enable_detect: bool = True
    enable_ocr: bool = True
    enable_translate: bool = True
    enable_inpaint: bool = True
@@ -25,6 +26,22 @@ class ModuleConfig(Config):
    def get_params(self, module_key: str) -> dict:
        """Look up the entry stored under ``<module_key>_params`` on this config."""
        params_key = f'{module_key}_params'
        return self[params_key]
    
    def stage_enabled(self, idx: int):
        """Return the enable flag of the pipeline stage at position ``idx``.

        Stage order: 0 = detect, 1 = OCR, 2 = translate, 3 = inpaint.

        Raises:
            Exception: if ``idx`` is not one of the four supported positions.
        """
        flag_names = (
            'enable_detect',
            'enable_ocr',
            'enable_translate',
            'enable_inpaint',
        )
        for position, flag_name in enumerate(flag_names):
            if idx == position:
                return getattr(self, flag_name)
        raise Exception(f'not supported stage idx: {idx}')
        
    def all_stages_disabled(self):
        """Return True when none of the four pipeline stages is enabled.

        The flags are tested by truthiness rather than by identity with the
        ``False`` singleton, so falsy non-bool values (such as ``0`` or
        ``None``) are also treated as "disabled".
        """
        return not (
            self.enable_detect
            or self.enable_ocr
            or self.enable_translate
            or self.enable_inpaint
        )
        

@nested_dataclass
class DrawPanelConfig(Config):
    pentool_color: List = field(default_factory=lambda: [0, 0, 0])
@@ -79,6 +96,8 @@ class ProgramConfig(Config):
    mt_sublist: dict = field(default_factory=lambda: [])
    display_lang: str = C.DEFAULT_DISPLAY_LANG



    @staticmethod
    def load(cfg_path: str):
        
+19 −6
Original line number Diff line number Diff line
@@ -245,9 +245,6 @@ class MainWindow(FramelessWindow):
        self.show_trans_text(pcfg.show_trans_text)
        self.show_source_text(pcfg.show_source_text)

        self.bottomBar.ocrChecker.setCheckState(pcfg.module.enable_ocr)
        self.bottomBar.transChecker.setChecked(pcfg.module.enable_translate)

        self.module_manager = module_manager = ModuleManager(self.imgtrans_proj)
        module_manager.update_translator_status.connect(self.updateTranslatorStatus)
        module_manager.update_source_download_status.connect(self.updateSourceDownloadStatus)
@@ -265,8 +262,6 @@ class MainWindow(FramelessWindow):

        self.leftBar.run_imgtrans.connect(self.on_run_imgtrans)
        self.leftBar.run_sync_source.connect(self.on_run_sync_source)
        self.bottomBar.ocrcheck_statechanged.connect(module_manager.setOCRMode)
        self.bottomBar.transcheck_statechanged.connect(module_manager.setTransMode)
        self.bottomBar.inpaint_btn_clicked.connect(self.inpaintBtnClicked)
        self.bottomBar.source_download_btn_clicked.connect(self.SourceDownloadBtnClicked)
        self.bottomBar.translatorStatusbtn.clicked.connect(self.translatorStatusBtnPressed)
@@ -935,7 +930,7 @@ class MainWindow(FramelessWindow):
            blk.line_spacing = gf.line_spacing
            blk.letter_spacing = gf.letter_spacing
            sw = blk.stroke_width
            if sw > 0:
            if sw > 0 and pcfg.module.enable_ocr:
                blk.font_size = int(blk.font_size / (1 + sw))

        self.st_manager.auto_textlayout_flag = pcfg.let_autolayout_flag
@@ -1011,6 +1006,24 @@ class MainWindow(FramelessWindow):
        if self.bottomBar.textblockChecker.isChecked():
            self.bottomBar.textblockChecker.click()
        self.postprocess_mt_toggle = False

        all_disabled = pcfg.module.all_stages_disabled()
        if pcfg.module.enable_detect:
            for page in self.imgtrans_proj.pages:
                self.imgtrans_proj.pages[page].clear()
        else:
            self.st_manager.updateTextBlkList()
            textblk: TextBlock = None
            for blklist in self.imgtrans_proj.pages.values():
                for textblk in blklist:
                    if pcfg.module.enable_ocr:
                        textblk.stroke_decide_by_colordiff = True
                        textblk.default_stroke_width = 0.2
                        textblk.text = []
                        textblk.set_font_colors((0, 0, 0), (0, 0, 0), True)
                    if pcfg.module.enable_translate or all_disabled:
                        textblk.rich_text = ''
                    textblk.vertical = textblk.src_is_vertical
        self.module_manager.runImgtransPipeline()

    def on_run_sync_source(self):
+28 −19
Original line number Diff line number Diff line
@@ -384,9 +384,23 @@ class TitleBar(Widget):

        self.runToolBtn = TitleBarToolBtn(self)
        self.runToolBtn.setText(self.tr('Run'))

        self.stageActions = stageActions = [
            QAction(self.tr('Enable Text Dection'), self),
            QAction(self.tr('Enable OCR'), self),
            QAction(self.tr('Enable Translation'), self),
            QAction(self.tr('Enable Inpainting'), self)
        ]
        for idx, sa in enumerate(stageActions):
            sa.setCheckable(True)
            sa.setChecked(pcfg.module.stage_enabled(idx))
            sa.triggered.connect(self.stageEnableStateChanged)

        runAction = QAction(self.tr('Run'), self)
        translatePageAction = QAction(self.tr('Translate page'), self)
        runMenu = QMenu(self.runToolBtn)
        runMenu.addActions(stageActions)
        runMenu.addSeparator()
        runMenu.addActions([runAction, translatePageAction])
        self.runToolBtn.setMenu(runMenu)
        self.runToolBtn.setPopupMode(QToolButton.InstantPopup)
@@ -425,6 +439,19 @@ class TitleBar(Widget):
        hlayout.setContentsMargins(0, 0, 0, 0)
        hlayout.setSpacing(0)

    def stageEnableStateChanged(self):
        """Copy the checked state of the triggering stage action into ``pcfg.module``.

        The sender's position in ``self.stageActions`` selects which flag is
        written: 0 = detect, 1 = OCR, 2 = translate, 3 = inpaint.
        """
        action = self.sender()
        stage_idx = self.stageActions.index(action)
        is_enabled = action.isChecked()
        flag_names = (
            'enable_detect',
            'enable_ocr',
            'enable_translate',
            'enable_inpaint',
        )
        # Positions beyond the known stages are ignored.
        if stage_idx < len(flag_names):
            setattr(pcfg.module, flag_names[stage_idx], is_enabled)

    def onMaxBtnClicked(self):
        if self.mainwindow.isMaximized():
            self.mainwindow.showNormal()
@@ -499,8 +526,6 @@ class BottomBar(Widget):
    textedit_checkchanged = Signal()
    paintmode_checkchanged = Signal()
    textblock_checkchanged = Signal()
    ocrcheck_statechanged = Signal(bool)
    transcheck_statechanged = Signal(bool)
    inpaint_btn_clicked = Signal()
    source_download_btn_clicked = Signal()

@@ -510,14 +535,6 @@ class BottomBar(Widget):
        self.setMouseTracking(True)
        self.mainwindow = mainwindow

        self.ocrChecker = TextChecker('ocr')
        self.ocrChecker.setObjectName('OCRChecker')
        self.ocrChecker.setToolTip(self.tr('Enable/disable ocr'))
        self.ocrChecker.checkStateChanged.connect(self.OCRStateChanged)
        self.transChecker = QCheckBox()
        self.transChecker.setObjectName('TransChecker')
        self.transChecker.setToolTip(self.tr('Enable/disable translation'))
        self.transChecker.clicked.connect(self.transCheckerStateChanged)
        self.translatorStatusbtn = TranslatorStatusButton()
        self.translatorStatusbtn.setHidden(True)
        self.transTranspageBtn = RunStopTextBtn(self.tr('translate page'),
@@ -551,8 +568,6 @@ class BottomBar(Widget):
        self.textlayerSlider.setValue(100)
        self.textlayerSlider.setRange(0, 100)
        
        self.hlayout.addWidget(self.ocrChecker)
        self.hlayout.addWidget(self.transChecker)
        self.hlayout.addWidget(self.translatorStatusbtn)
        self.hlayout.addWidget(self.transTranspageBtn)
        self.hlayout.addWidget(self.inpainterStatBtn)
@@ -579,12 +594,6 @@ class BottomBar(Widget):
    def onTextblockCheckerClicked(self):
        """Emit ``textblock_checkchanged`` with no arguments."""
        notify = self.textblock_checkchanged.emit
        notify()

    def OCRStateChanged(self):
        """Emit ``ocrcheck_statechanged`` with the OCR checker's checked state."""
        notify = self.ocrcheck_statechanged.emit
        is_checked = self.ocrChecker.isChecked()
        notify(is_checked)
        
    def transCheckerStateChanged(self):
        """Emit ``transcheck_statechanged`` with the translation checkbox's checked state."""
        notify = self.transcheck_statechanged.emit
        is_checked = self.transChecker.isChecked()
        notify(is_checked)

    def inpaintBtnClicked(self):
        """Emit ``inpaint_btn_clicked`` with no arguments."""
        notify = self.inpaint_btn_clicked.emit
        notify()

Loading