Unverified Commit 29e78d52 authored by dmMaze's avatar dmMaze Committed by GitHub
Browse files

Merge pull request #21 from dmMaze/sugoi_translator

Update translate module & add Sugoi translator
parents e25c82a0 565f10ed
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -77,6 +77,12 @@ class OpenCVInpainter(InpainterBase):
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        return self.inpaint_method(img, mask)

    def is_computational_intensive(self) -> bool:
        """Flag this inpainter as computationally intensive; unconditionally ``True``."""
        return True
    
    def is_cpu_intensive(self) -> bool:
        """Flag this inpainter as CPU-intensive; unconditionally ``True``."""
        return True


@register_inpainter('patchmatch')
class PatchmatchInpainter(InpainterBase):
@@ -88,6 +94,12 @@ class PatchmatchInpainter(InpainterBase):
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        return self.inpaint_method(img, mask)

    def is_computational_intensive(self) -> bool:
        """Flag this inpainter as computationally intensive; unconditionally ``True``."""
        return True
    
    def is_cpu_intensive(self) -> bool:
        """Flag this inpainter as CPU-intensive; unconditionally ``True``."""
        return True


import torch
from utils.imgproc_utils import resize_keepasp
+17 −0
Original line number Diff line number Diff line
from typing import Dict

GPUINTENSIVE_SET = {'cuda', 'hip'}

class ModuleParamParser:

    setup_params: Dict = None
@@ -16,6 +18,21 @@ class ModuleParamParser:
            if param_dict['type'] == 'selector':
                param_dict['select'] = param_content

    def is_cpu_intensive(self) -> bool:
        """Return ``True`` when the module's ``device`` parameter is set to ``'cpu'``.

        Modules without setup parameters, or without a ``'device'`` entry,
        report ``False``.
        """
        params = self.setup_params
        if params is None or 'device' not in params:
            return False
        return params['device']['select'] == 'cpu'

    def is_gpu_intensive(self) -> bool:
        """Return ``True`` when the selected ``device`` is listed in ``GPUINTENSIVE_SET``.

        Modules without setup parameters, or without a ``'device'`` entry,
        report ``False``.
        """
        params = self.setup_params
        if params is None or 'device' not in params:
            return False
        selected_device = params['device']['select']
        return selected_device in GPUINTENSIVE_SET

    def is_computational_intensive(self) -> bool:
        """Return ``True`` when the module exposes a ``'device'`` setup parameter.

        The selected device value is not inspected; only the presence of the
        ``'device'`` key matters.
        """
        params = self.setup_params
        return params is not None and 'device' in params


import torch

+100 −41
Original line number Diff line number Diff line
import urllib.request
from typing import Dict, List, Union
import time, requests, re, uuid, base64, hmac
import functools
import json
from typing import Dict, List, Union, Set
import time, requests, re, uuid, base64, hmac, functools, json, deepl
import ctranslate2, sentencepiece as spm
from .exceptions import InvalidSourceOrTargetLanguage, TranslatorSetupFailure, MissingTranslatorParams, TranslatorNotValid
from ..textdetector.textblock import TextBlock
from ..moduleparamparser import ModuleParamParser
from utils.registry import Registry
from utils.io_utils import text_is_empty
import deepl

TRANSLATORS = Registry('translators')
register_translator = TRANSLATORS.register_module
@@ -41,24 +39,34 @@ SYSTEM_LANGMAP = {
}


def check_language_support(check_type: str = 'source'):
    """Build a decorator that validates the language handed to a setter method.

    Args:
        check_type: ``'source'`` validates against ``self.supported_src_list``;
            any other value validates against ``self.supported_tgt_list``.

    Returns:
        A decorator for methods with the signature ``method(self, lang)``.

    Raises:
        InvalidSourceOrTargetLanguage: raised by the decorated method when
            ``lang`` is not in the relevant supported list; the exception's
            ``message`` carries the supported languages, one per line.
    """

    def decorator(set_lang_method):
        @functools.wraps(set_lang_method)
        def wrapper(self, lang: str = ''):
            # Choose the list matching the direction this setter controls.
            if check_type == 'source':
                supported_lang_list = self.supported_src_list
            else:
                supported_lang_list = self.supported_tgt_list
            if lang not in supported_lang_list:
                msg = '\n'.join(supported_lang_list)
                raise InvalidSourceOrTargetLanguage(f'Invalid {check_type}: {lang}\n', message=msg)
            return set_lang_method(self, lang)
        return wrapper

    return decorator


class TranslatorBase(ModuleParamParser):

    concate_text = True
    
    def __init__(self,
                 lang_source: str = None, 
                 lang_target: str = None,
                 lang_source: str, 
                 lang_target: str,
                 raise_unsupported_lang: bool = True,
                 **setup_params) -> None:
        super().__init__(**setup_params)
        self.sys_lang = SYSTEM_LANGMAP[SYSTEM_LANG] if SYSTEM_LANG in SYSTEM_LANGMAP else ''
        self.name = ''
        for key in TRANSLATORS.module_dict:
            if TRANSLATORS.module_dict[key] == self.__class__:
@@ -77,23 +85,19 @@ class TranslatorBase(ModuleParamParser):
            else:
                raise TranslatorSetupFailure(e)

        if lang_source is None or lang_source == '':
            lang_source = 'Auto' if self.support_auto_souce() else self.default_lang()
        if lang_target is None or lang_target == '':
            lang_target = self.default_lang()
        self.valid_lang_list = [lang for lang in self.lang_map if self.lang_map[lang] != '']

        try:
            self.set_source(lang_source)
            self.set_target(lang_target)
        except InvalidSourceOrTargetLanguage as e:
            if raise_unsupported_lang:
                raise e
            else:
                lang_source = self.supported_src_list[0]
                lang_target = self.supported_tgt_list[0]
                self.set_source(lang_source)
                self.set_target(lang_target)

    def support_auto_souce(self):
        if self.lang_map['Auto']:
            return True
        return False

    def default_lang(self):
        if self.sys_lang in self.lang_map:
            return self.sys_lang
        return self.supported_languages()[0]

    def _setup_translator(self):
        raise NotImplementedError
@@ -101,11 +105,11 @@ class TranslatorBase(ModuleParamParser):
    def setup_translator(self):
        self._setup_translator()

    @check_language_support
    @check_language_support(check_type='source')
    def set_source(self, lang: str):
        self.lang_source = lang

    @check_language_support
    @check_language_support(check_type='target')
    def set_target(self, lang: str):
        self.lang_target = lang

@@ -151,13 +155,15 @@ class TranslatorBase(ModuleParamParser):
            blk.translation = tr

    def supported_languages(self) -> List[str]:
        return [lang for lang in self.lang_map if self.lang_map[lang]]
        return self.valid_lang_list

    def support_language(self, lang: str) -> bool:
        if lang in self.lang_map:
            if self.lang_map[lang]:
                return True
        return False
    @property
    def supported_tgt_list(self) -> List[str]:
        return self.valid_lang_list

    @property
    def supported_src_list(self) -> List[str]:
        return self.valid_lang_list


@register_translator('google')
@@ -346,6 +352,59 @@ class DeeplTranslator(TranslatorBase):
        return [i.text for i in result]
    


SUGOIMODEL_TRANSLATOR_DIRPATH = 'data/models/sugoi_translator'
# Joined with '/' instead of a hard-coded backslash so the path also resolves
# on non-Windows hosts.
SUGOIMODEL_TOKENIZATOR_PATH = SUGOIMODEL_TRANSLATOR_DIRPATH + '/spm.ja.nopretok.model'


@register_translator('Sugoi')
class SugoiTranslator(TranslatorBase):
    """Translator backed by a local CTranslate2 model and a SentencePiece tokenizer.

    Only Japanese ('日本語') is accepted as the source language and only
    'English' as the target (see ``supported_src_list`` / ``supported_tgt_list``).
    The model directory and tokenizer file are read from
    ``SUGOIMODEL_TRANSLATOR_DIRPATH`` and ``SUGOIMODEL_TOKENIZATOR_PATH``.
    """

    # Each text is passed to the model as its own batch entry.
    concate_text = False
    setup_params: Dict = {
        'device': {
            'type': 'selector',
            'options': ['cpu', 'cuda'],
            'select': 'cpu'
        }
    }

    def _load_translator(self):
        """Create the CTranslate2 translator on the currently selected device."""
        device = self.setup_params['device']['select']
        return ctranslate2.Translator(SUGOIMODEL_TRANSLATOR_DIRPATH, device=device)

    def _setup_translator(self):
        """Register the language codes and load the model and the tokenizer."""
        self.lang_map['日本語'] = 'ja'
        self.lang_map['English'] = 'en'

        self.translator = self._load_translator()
        self.tokenizator = spm.SentencePieceProcessor(model_file=SUGOIMODEL_TOKENIZATOR_PATH)

    def _translate(self, text: Union[str, List]) -> Union[str, List]:
        """Translate one string or a list of strings.

        Args:
            text: a single string, or a list of strings translated as one batch.

        Returns:
            A string when ``text`` was a string, otherwise a list with one
            translation per input entry.
        """
        input_is_lst = True
        if isinstance(text, str):
            text = [text]
            input_is_lst = False

        # Periods are swapped for '@' before translation and restored afterwards.
        # NOTE(review): the second search string is assumed to be the full-width
        # period; an empty search string would insert '@' between every character.
        text = [i.replace(".", "@").replace("．", "@") for i in text]
        tokenized_text = self.tokenizator.encode(text, out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
        tokenized_translated = self.translator.translate_batch(tokenized_text)
        # '▁' (U+2581) is SentencePiece's word-boundary marker; turning it into a
        # space rebuilds plain text from the best hypothesis' tokens.
        text_translated = [
            ''.join(result[0]["tokens"]).replace('▁', ' ').replace("@", ".")
            for result in tokenized_translated
        ]

        if not input_is_lst:
            return text_translated[0]
        return text_translated

    def updateParam(self, param_key: str, param_content):
        """Update a setup parameter; reload the model when the device changes."""
        super().updateParam(param_key, param_content)
        if param_key == 'device':
            # Drop the old translator first so it can be released before the
            # replacement is created on the newly selected device.
            if hasattr(self, 'translator'):
                delattr(self, 'translator')
            self.translator = self._load_translator()

    @property
    def supported_tgt_list(self) -> List[str]:
        """Target languages this translator can produce."""
        return ['English']

    @property
    def supported_src_list(self) -> List[str]:
        """Source languages this translator accepts."""
        return ['日本語']


# # "dummy translator" is the name showed in the app
# @register_translator('dummy translator')
# class DummyTranslator(TranslatorBase):
+32 −9
Original line number Diff line number Diff line
import sys, os
import os.path as osp
sys.path.append(osp.dirname(osp.dirname(__file__)))
from dl.translators import TranslatorBase, GoogleTranslator, PapagoTranslator, TRANSLATORS, CaiyunTranslator, DeeplTranslator
from dl.translators import *
from ui.constants import PROGRAM_PATH
os.chdir(PROGRAM_PATH)

def test_translator(translator: TranslatorBase, test_list: List):
    """Run a translator over ``test_list`` and check the shape of what it returns.

    Args:
        translator: translator under test; its source/target languages are
            reset for every entry of ``test_list``.
        test_list: dicts with ``'source'``, ``'target'`` and ``'text_list'``
            keys; each text is either a string or a list of strings.

    Every translation must have the same type as its input, and list inputs
    must yield a list of equal length. Empty inputs are exercised at the end.
    """
    for test_dict in test_list:
        translator.set_source(test_dict['source'])
        translator.set_target(test_dict['target'])
        for text in test_dict['text_list']:
            # Translate once and reuse the result for printing and checking.
            translation = translator.translate(text)
            print(f'src: {text}, translation: {translation}')
            assert type(translation) is type(text)
            if isinstance(translation, list):
                assert len(translation) == len(text)

    # A list of empty strings must still come back with one entry per input.
    text = ['', '', '', '', '', '', '']
    translation = translator.translate(text)
    assert len(translation) == len(text)
    print(f'src: {text}, translation: {translation}')
    # A single empty string must not make the translator fail.
    text = ''
    translation = translator.translate(text)
    print(f'src: {text}, translation: {translation}')

engchscht_test_list = [
    {
@@ -41,14 +50,28 @@ engchscht_test_list = [
    }
]

# Japanese -> English fixtures for test_translator: one plain string and one
# list of strings, so both input shapes are exercised.
jaeng_test_list = [
    {
        'source': '日本語',
        'target': 'English',
        'text_list': [
            '日本語のテスト',
            ['日本語の...テスト', 'ククク…何かしらねぇ 当ててごらんなさい']
        ]
    },
]

if __name__ == '__main__':

    # Device selected for the Sugoi translator below.
    device = 'cuda'

    caiyun_setup_params = {
        'token': 'invalidtoken',
    }
    # The other translators are kept here, commented out, so they can be
    # re-enabled by hand; only the Sugoi translator is exercised.
    # ctranslator = CaiyunTranslator('简体中文', 'English', **caiyun_setup_params)
    # ptranslator = PapagoTranslator('简体中文', 'English')
    # gtranslator = GoogleTranslator('简体中文', 'English')
    # dtranslator = DeeplTranslator('简体中文', 'English')
    sugoi_translator = SugoiTranslator('日本語', 'English', device={'select': device})
    test_translator(sugoi_translator, jaeng_test_list)
+16 −20
Original line number Diff line number Diff line
@@ -158,18 +158,12 @@ class TranslateThread(ModuleThread):
            setup_params = self.dl_config.translator_setup_params[translator]
            translator_module: TranslatorBase = TRANSLATORS.module_dict[translator]
            if setup_params is not None:
                self.translator = translator_module(source, target, **setup_params)
                self.translator = translator_module(source, target, raise_unsupported_lang=False, **setup_params)
            else:
                self.translator = translator_module(source, target)
                self.translator = translator_module(source, target, raise_unsupported_lang=False)
            self.dl_config.translate_source = self.translator.lang_source
            self.dl_config.translate_target = self.translator.lang_target
            self.dl_config.translator = self.translator.name
        except InvalidSourceOrTargetLanguage as e:
            msg = self.tr('The selected language is not supported by ') + translator + '.\n'
            msg += self.tr('support list: ') + '\n'
            msg += e.message
            self.translator = old_translator
            self.exception_occurred.emit(msg, '', traceback.format_exc())
        except Exception as e:
            self.translator = old_translator
            msg = self.tr('Failed to set translator ') + translator
@@ -253,7 +247,6 @@ class ImgtransThread(QThread):
    exception_occurred = Signal(str, str)
    def __init__(self, 
                 dl_config: DLModuleConfig, 
                 imgtrans_proj: ProjImgTrans, 
                 textdetect_thread: TextDetectThread,
                 ocr_thread: OCRThread,
                 translate_thread: TranslateThread,
@@ -261,13 +254,13 @@ class ImgtransThread(QThread):
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dl_config = dl_config
        self.imgtrans_proj: ProjImgTrans = None
        self.textdetect_thread = textdetect_thread
        self.ocr_thread = ocr_thread
        self.translate_thread = translate_thread
        self.inpaint_thread = inpaint_thread
        self.job = None
        self.translate_mode = 1
        self.imgtrans_proj: ProjImgTrans = None


    @property
    def textdetector(self) -> TextDetectorBase:
@@ -296,8 +289,11 @@ class ImgtransThread(QThread):
        self.translate_counter = 0
        self.inpaint_counter = 0
        self.num_pages = num_pages = len(self.imgtrans_proj.pages)
        if self.dl_config.enable_translate and self.translate_mode == 1:

        self.parallel_trans = not self.translator.is_computational_intensive()
        if self.dl_config.enable_translate and self.parallel_trans:
            self.translate_thread.runTranslatePipeline(self.imgtrans_proj)

        for imgname in self.imgtrans_proj.pages:
            img = self.imgtrans_proj.read_img(imgname)

@@ -315,12 +311,12 @@ class ImgtransThread(QThread):

                if self.dl_config.enable_translate:
                    try:
                        if self.translate_mode == 0:
                        if self.parallel_trans:
                            self.translate_thread.push_pagekey_queue(imgname)
                        else:
                            self.translator.translate_textblk_lst(blk_list)
                            self.translate_counter += 1
                            self.update_translate_progress.emit(self.translate_counter)
                        else:
                            self.translate_thread.push_pagekey_queue(imgname)
                    except Exception as e:
                        self.dl_config.enable_translate = False
                        self.update_translate_progress.emit(num_pages)
@@ -347,7 +343,7 @@ class ImgtransThread(QThread):
            or not self.dl_config.enable_ocr \
            or not self.dl_config.enable_translate:
            return True
        if self.translate_mode == 1:
        if self.parallel_trans:
            return self.translate_thread.pipeline_finished()
        return self.translate_counter == self.num_pages

@@ -366,10 +362,10 @@ class ImgtransThread(QThread):
        if self.dl_config.enable_ocr:
            counter = min(counter, self.ocr_counter)
            if self.dl_config.enable_translate:
                if self.translate_mode == 0:
                    counter = min(counter, self.translate_counter)
                else:
                if self.parallel_trans:
                    counter = min(counter, self.translate_thread.finished_counter)
                else:
                    counter = min(counter, self.translate_counter)
                    
        if self.dl_config.enable_inpaint:
            counter = min(counter, self.inpaint_counter)
@@ -425,7 +421,7 @@ class DLManager(QObject):

        self.progress_msgbox = ProgressMessageBox()

        self.imgtrans_thread = ImgtransThread(dl_config, imgtrans_proj, self.textdetect_thread, self.ocr_thread, self.translate_thread, self.inpaint_thread)
        self.imgtrans_thread = ImgtransThread(dl_config, self.textdetect_thread, self.ocr_thread, self.translate_thread, self.inpaint_thread)
        self.imgtrans_thread.update_detect_progress.connect(self.on_update_detect_progress)
        self.imgtrans_thread.update_ocr_progress.connect(self.on_update_ocr_progress)
        self.imgtrans_thread.update_translate_progress.connect(self.on_update_translate_progress)
Loading