Commit fc9b06a6 authored by dmMaze's avatar dmMaze
Browse files

Support keyword substitution for machine translation & OCR results

parent eda4bf87
Loading
Loading
Loading
Loading
+22 −4
Original line number Diff line number Diff line
from typing import Tuple, List, Dict, Union
from typing import Tuple, List, Dict, Union, Callable
from ordered_set import OrderedSet
import numpy as np
import cv2
import logging
@@ -20,6 +21,7 @@ class OCRBase(ModuleParamParser):
            if OCR.module_dict[key] == self.__class__:
                self.name = key
                break
        self.postprocess_hooks: OrderedSet[Callable] = OrderedSet()
        self.setup_ocr()

    def setup_ocr(self):
@@ -27,17 +29,33 @@ class OCRBase(ModuleParamParser):

    def run_ocr(self, img: np.ndarray, blk_list: List[TextBlock] = None) -> Union[List[TextBlock], str]:
        """Run OCR and pass the recognized text through ``self.postprocess_hooks``.

        Args:
            img: Image handed to ``ocr_img`` / ``ocr_blk_list``.
            blk_list: ``None`` to OCR the whole image, a single ``TextBlock``,
                or a list of ``TextBlock`` objects.

        Returns:
            The post-processed string from ``ocr_img`` when ``blk_list`` is
            ``None``; otherwise the list of blocks, whose ``text`` entries
            have been post-processed in place.
        """
        hooks = self.postprocess_hooks

        if blk_list is None:
            # Whole-image mode: hooks are called with the text only.
            text = self.ocr_img(img)
            for callback in hooks:
                text = callback(text)
            return text

        if isinstance(blk_list, TextBlock):
            blk_list = [blk_list]

        # ocr_blk_list is annotated ``-> None``, so the blocks are read back
        # from ``blk_list`` itself rather than from a return value.
        self.ocr_blk_list(img, blk_list)
        for blk in blk_list:
            for ii, line in enumerate(blk.text):
                # Chain the hooks: each one receives the previous hook's
                # output, so every registered hook contributes to the result.
                for callback in hooks:
                    line = callback(line, blk=blk)
                blk.text[ii] = line
        return blk_list

    def ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]) -> None:
        """Base-class placeholder for per-block OCR; always raises here.

        NOTE(review): presumably concrete OCR modules override this and write
        the recognized text into each block in place (the ``-> None``
        annotation suggests so) -- confirm against the subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def ocr_img(self, img: np.ndarray) -> str:
        """Base-class placeholder for whole-image OCR; always raises here.

        NOTE(review): presumably concrete OCR modules override this and return
        the text recognized in ``img`` (per the ``-> str`` annotation) --
        confirm against the subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def register_postprocess_hooks(self, callbacks: Union[List, Callable]):
        """Add one or more post-processing callbacks to ``self.postprocess_hooks``.

        Args:
            callbacks: A single callable, an iterable of callables, or
                ``None`` (in which case nothing is registered).
        """
        if callbacks is None:
            return
        # Built-in ``callable()`` is the idiomatic "can this be called?" test;
        # it avoids an isinstance check against the ``typing.Callable`` alias.
        if callable(callbacks):
            callbacks = [callbacks]
        for callback in callbacks:
            self.postprocess_hooks.add(callback)


from .model_32px import OCR32pxModel
OCR32PXMODEL: OCR32pxModel = None
+19 −1
Original line number Diff line number Diff line
import urllib.request
from typing import Dict, List, Union, Set
from ordered_set import OrderedSet
from typing import Dict, List, Union, Set, Callable
import time, requests, re, uuid, base64, hmac, functools, json, deepl
import ctranslate2, sentencepiece as spm
from .exceptions import InvalidSourceOrTargetLanguage, TranslatorSetupFailure, MissingTranslatorParams, TranslatorNotValid
@@ -76,6 +77,7 @@ class TranslatorBase(ModuleParamParser):
        self.lang_source: str = lang_source
        self.lang_target: str = lang_target
        self.lang_map: Dict = LANGMAP_GLOBAL.copy()
        self.postprocess_hooks = OrderedSet()
        
        try:
            self.setup_translator()
@@ -99,6 +101,14 @@ class TranslatorBase(ModuleParamParser):
                self.set_source(lang_source)
                self.set_target(lang_target)

    def register_postprocess_hooks(self, callbacks: Union[List, Callable]):
        """Add one or more post-processing callbacks to ``self.postprocess_hooks``.

        Args:
            callbacks: A single callable, an iterable of callables, or
                ``None`` (in which case nothing is registered).
        """
        if callbacks is None:
            return
        # Built-in ``callable()`` is the idiomatic "can this be called?" test;
        # it avoids an isinstance check against the ``typing.Callable`` alias.
        if callable(callbacks):
            callbacks = [callbacks]
        for callback in callbacks:
            self.postprocess_hooks.add(callback)

    def _setup_translator(self):
        """Base-class placeholder; always raises ``NotImplementedError`` here.

        NOTE(review): presumably overridden by concrete translators to perform
        their own setup -- confirm against the subclasses.
        """
        raise NotImplementedError

@@ -135,6 +145,12 @@ class TranslatorBase(ModuleParamParser):
            
        if isinstance(text, List):
            assert len(text_trans) == len(text)
            for ii, t in enumerate(text_trans):
                for callback in self.postprocess_hooks:
                    text_trans[ii] = callback(t)
        else:
            for callback in self.postprocess_hooks:
                text_trans = callback(text_trans)

        return text_trans

@@ -152,6 +168,8 @@ class TranslatorBase(ModuleParamParser):
        text_list = [blk.get_text() for blk in textblk_lst]
        translations = self.translate(text_list)
        for tr, blk in zip(translations, textblk_lst):
            for callback in self.postprocess_hooks:
                tr = callback(tr, blk=blk)
            blk.translation = tr

    def supported_languages(self) -> List[str]:
+6 −7
Original line number Diff line number Diff line
@@ -355,7 +355,6 @@ class ConfigPanel(Widget):
        generalConfigPanel.addTextLabel(label_saladict)

        sublock = ConfigSubBlock(ConfigTextLabel(self.tr("<a href=\"https://github.com/dmMaze/BallonsTranslator/tree/master/doc/saladict.md\">Installation guide</a>"), CONFIG_FONTSIZE_CONTENT - 2), vertical_layout=False)
        
        sublock.layout().addItem(QSpacerItem(0, 0, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding))
        generalConfigPanel.addSublock(sublock)

@@ -439,14 +438,14 @@ class ConfigPanel(Widget):
    def on_effect_flag_changed(self):
        """Store the current index of ``let_effect_combox`` in ``config.let_fnteffect_flag``."""
        self.config.let_fnteffect_flag = self.let_effect_combox.currentIndex()

    def on_source_flag_changed(self):
        self.config.src_choice_flag = self.src_choice_combox.currentIndex()
    # def on_source_flag_changed(self):
    #     self.config.src_choice_flag = self.src_choice_combox.currentIndex()

    def on_source_link_changed(self):
        self.config.src_link_flag = self.src_link_textbox.text()
    # def on_source_link_changed(self):
    #     self.config.src_link_flag = self.src_link_textbox.text()

    def on_source_force_download_changed(self):
        self.config.src_force_download_flag = self.src_force_download_checker.isChecked()
    # def on_source_force_download_changed(self):
    #     self.config.src_force_download_flag = self.src_force_download_checker.isChecked()

    def focusOnTranslator(self):
        idx0, idx1 = self.trans_sub_block.idx0, self.trans_sub_block.idx1
+7 −3
Original line number Diff line number Diff line
@@ -295,7 +295,7 @@ class ImgtransThread(QThread):

    def _blktrans_pipeline(self, blk_list: List[TextBlock], tgt_img: np.ndarray, mode: int):
        if mode >= 0:
            self.ocr_thread.module.ocr_blk_list(tgt_img, blk_list)
            self.ocr_thread.module.run_ocr(tgt_img, blk_list)
            self.finish_blktrans_stage.emit('ocr', 100)
        if mode != 0:
            self.translate_thread.module.translate_textblk_lst(blk_list)
@@ -347,7 +347,7 @@ class ImgtransThread(QThread):
            self.imgtrans_proj.pages[imgname] = blk_list

            if self.dl_config.enable_ocr:
                self.ocr.ocr_blk_list(img, blk_list)
                self.ocr.run_ocr(img, blk_list)
                self.ocr_counter += 1
                self.update_ocr_progress.emit(self.ocr_counter)

@@ -457,7 +457,7 @@ class DLManager(QObject):
        self.dl_config = config.dl
        self.imgtrans_proj = imgtrans_proj

    def setupThread(self, config_panel: ConfigPanel, imgtrans_progress_msgbox: ImgtransProgressMessageBox):
    def setupThread(self, config_panel: ConfigPanel, imgtrans_progress_msgbox: ImgtransProgressMessageBox, ocr_postprocess: Callable = None, translate_postprocess: Callable = None):
        dl_config = self.dl_config
        self.textdetect_thread = TextDetectThread(dl_config)
        self.textdetect_thread.finish_set_module.connect(self.on_finish_setdetector)
@@ -495,6 +495,7 @@ class DLManager(QObject):
        translator_panel.source_combobox.currentTextChanged.connect(self.on_translatorsource_changed)
        translator_panel.target_combobox.currentTextChanged.connect(self.on_translatortarget_changed)
        translator_panel.paramwidget_edited.connect(self.on_translatorparam_edited)
        self.translate_postprocess = translate_postprocess

        self.inpaint_panel = inpainter_panel = config_panel.inpaint_config_panel
        inpainter_setup_params = merge_config_module_params(dl_config.inpainter_setup_params, VALID_INPAINTERS, INPAINTERS.get)
@@ -515,6 +516,7 @@ class DLManager(QObject):
        ocr_panel.setupModulesParamWidgets(ocr_setup_params)
        ocr_panel.paramwidget_edited.connect(self.on_ocrparam_edited)
        ocr_panel.ocr_changed.connect(self.setOCR)
        self.ocr_postprocess = ocr_postprocess

        self.setTextDetector()
        self.setOCR()
@@ -705,6 +707,7 @@ class DLManager(QObject):
        if self.ocr is not None:
            self.dl_config.ocr = self.ocr.name
            self.ocr_panel.setOCR(self.ocr.name)
            self.ocr_thread.module.register_postprocess_hooks(self.ocr_postprocess)
            LOGGER.info('OCR set to {}'.format(self.ocr.name))

    def on_finish_setinpainter(self):
@@ -720,6 +723,7 @@ class DLManager(QObject):
            self.dl_config.translator = translator.name
            self.update_translator_status.emit(self.dl_config.translator, self.dl_config.translate_source, self.dl_config.translate_target)
            self.translator_panel.finishSetTranslator(translator)
            self.translate_thread.module.register_postprocess_hooks(self.translate_postprocess)
            LOGGER.info('Translator set to {}'.format(self.translator.name))
        else:
            LOGGER.error('invalid translator')
+18 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from typing import List
from dl import VALID_INPAINTERS, VALID_TEXTDETECTORS, VALID_TRANSLATORS, VALID_OCR, \
    TranslatorBase, DEFAULT_DEVICE
from utils.logger import logger as LOGGER
from .stylewidgets import ConfigComboBox
from .stylewidgets import ConfigComboBox, NoBorderPushBtn
from .constants import CONFIG_FONTSIZE_CONTENT, CONFIG_COMBOBOX_MIDEAN, CONFIG_COMBOBOX_SHORT, CONFIG_COMBOBOX_HEIGHT

from qtpy.QtWidgets import QHBoxLayout, QVBoxLayout, QWidget, QLabel, QComboBox, QCheckBox, QLineEdit
@@ -189,6 +189,8 @@ class ModuleConfigParseWidget(QWidget):

class TranslatorConfigPanel(ModuleConfigParseWidget):

    show_MT_keyword_window = Signal()

    def __init__(self, module_name, *args, **kwargs) -> None:
        super().__init__(module_name, VALID_TRANSLATORS, *args, **kwargs)
        self.translator_combobox = self.module_combobox
@@ -196,6 +198,10 @@ class TranslatorConfigPanel(ModuleConfigParseWidget):
    
        self.source_combobox = ConfigComboBox()
        self.target_combobox = ConfigComboBox()
        self.replaceMTkeywordBtn = NoBorderPushBtn(self.tr("Keyword substitution for machine translation"), self)
        self.replaceMTkeywordBtn.clicked.connect(self.show_MT_keyword_window)
        self.replaceMTkeywordBtn.setFixedWidth(500)

        st_layout = QHBoxLayout()
        st_layout.setSpacing(15)
        st_layout.setAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter)
@@ -203,7 +209,9 @@ class TranslatorConfigPanel(ModuleConfigParseWidget):
        st_layout.addWidget(self.source_combobox)
        st_layout.addWidget(ParamNameLabel(self.tr('Target ')))
        st_layout.addWidget(self.target_combobox)
        
        self.vlayout.insertLayout(1, st_layout) 
        self.vlayout.addWidget(self.replaceMTkeywordBtn)

    def finishSetTranslator(self, translator: TranslatorBase):
        self.source_combobox.blockSignals(True)
@@ -242,8 +250,17 @@ class TextDetectConfigPanel(ModuleConfigParseWidget):


class OCRConfigPanel(ModuleConfigParseWidget):
    """Module config panel for OCR, with a button that requests the OCR keyword window."""

    # Emitted when the keyword-substitution button is clicked.
    show_OCR_keyword_window = Signal()

    def __init__(self, module_name: str, *args, **kwargs) -> None:
        """Build the panel over ``VALID_OCR`` and append the keyword-substitution button."""
        super().__init__(module_name, VALID_OCR, *args, **kwargs)

        # OCR-flavoured aliases for the generic members of the base widget.
        self.ocr_changed = self.module_changed
        self.ocr_combobox = self.module_combobox
        self.setOCR = self.setModule

        # Configure the button through a local, then publish it on the instance
        # and place it at the end of the panel's vertical layout.
        keyword_btn = NoBorderPushBtn(self.tr("Keyword substitution for OCR results"), self)
        keyword_btn.clicked.connect(self.show_OCR_keyword_window)
        keyword_btn.setFixedWidth(500)
        self.replaceOCRkeywordBtn = keyword_btn
        self.vlayout.addWidget(keyword_btn)
 No newline at end of file
Loading