Commit cafe0948 authored by dmMaze's avatar dmMaze
Browse files

update translate module & sugoi_translator

parent ae87e648
Loading
Loading
Loading
Loading
+63 −37
Original line number Diff line number Diff line
import urllib.request
from typing import Dict, List, Union
from typing import Dict, List, Union, Set
import time, requests, re, uuid, base64, hmac, functools, json, deepl
import ctranslate2, sentencepiece as spm
from .exceptions import InvalidSourceOrTargetLanguage, TranslatorSetupFailure, MissingTranslatorParams, TranslatorNotValid
@@ -39,24 +39,34 @@ SYSTEM_LANGMAP = {
}


def check_language_support(check_type: str = 'source'):
    """Build a decorator that validates the language handed to a setter.

    Args:
        check_type: ``'source'`` validates against ``self.supported_src_list``;
            any other value validates against ``self.supported_tgt_list``.

    Returns:
        A decorator for methods with the signature ``(self, lang)``.

    Raises:
        InvalidSourceOrTargetLanguage: raised by the wrapped method when
            ``lang`` is not in the selected list. Its ``message`` is the
            newline-joined list of supported languages.
    """

    def decorator(set_lang_method):
        @functools.wraps(set_lang_method)
        def wrapper(self, lang: str = ''):
            if check_type == 'source':
                supported_lang_list = self.supported_src_list
            else:
                supported_lang_list = self.supported_tgt_list
            if lang not in supported_lang_list:
                msg = '\n'.join(supported_lang_list)
                raise InvalidSourceOrTargetLanguage(f'Invalid {check_type}: {lang}\n', message=msg)
            return set_lang_method(self, lang)
        return wrapper

    return decorator


class TranslatorBase(ModuleParamParser):

    concate_text = True
    
    def __init__(self,
                 lang_source: str = None, 
                 lang_target: str = None,
                 lang_source: str, 
                 lang_target: str,
                 raise_unsupported_lang: bool = True,
                 **setup_params) -> None:
        super().__init__(**setup_params)
        self.sys_lang = SYSTEM_LANGMAP[SYSTEM_LANG] if SYSTEM_LANG in SYSTEM_LANGMAP else ''
        self.name = ''
        for key in TRANSLATORS.module_dict:
            if TRANSLATORS.module_dict[key] == self.__class__:
@@ -75,23 +85,19 @@ class TranslatorBase(ModuleParamParser):
            else:
                raise TranslatorSetupFailure(e)

        if lang_source is None or lang_source == '':
            lang_source = 'Auto' if self.support_auto_souce() else self.default_lang()
        if lang_target is None or lang_target == '':
            lang_target = self.default_lang()
        self.valid_lang_list = [lang for lang in self.lang_map if self.lang_map[lang] != '']

        try:
            self.set_source(lang_source)
            self.set_target(lang_target)
        except InvalidSourceOrTargetLanguage as e:
            if raise_unsupported_lang:
                raise e
            else:
                lang_source = self.supported_src_list[0]
                lang_target = self.supported_tgt_list[0]
                self.set_source(lang_source)
                self.set_target(lang_target)

    def support_auto_souce(self):
        """Report whether the ``'Auto'`` entry of ``lang_map`` is non-empty."""
        return bool(self.lang_map['Auto'])

    def default_lang(self):
        """Pick a fallback language for this translator.

        Returns:
            ``self.sys_lang`` when it has a non-empty entry in ``lang_map``,
            otherwise the first item of ``self.supported_languages()``.
        """
        # Key membership alone is not enough: an empty value in lang_map marks
        # a language the translator does not support.
        if self.sys_lang in self.lang_map and self.lang_map[self.sys_lang]:
            return self.sys_lang
        return self.supported_languages()[0]

    def _setup_translator(self):
        """Hook with no base implementation; always raises here.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError
@@ -99,11 +105,11 @@ class TranslatorBase(ModuleParamParser):
    def setup_translator(self):
        """Run the translator set-up by delegating to ``_setup_translator``."""
        self._setup_translator()

    @check_language_support(check_type='source')
    def set_source(self, lang: str):
        """Store ``lang`` as the source language.

        The decorator validates ``lang`` against the supported source list
        before this body runs.
        """
        self.lang_source = lang

    @check_language_support(check_type='target')
    def set_target(self, lang: str):
        """Store ``lang`` as the target language.

        The decorator validates ``lang`` against the supported target list
        before this body runs.
        """
        self.lang_target = lang

@@ -149,13 +155,15 @@ class TranslatorBase(ModuleParamParser):
            blk.translation = tr

    def supported_languages(self) -> List[str]:
        """Return the names in ``lang_map`` whose mapped value is non-empty."""
        return [lang for lang in self.lang_map if self.lang_map[lang]]

    @property
    def supported_tgt_list(self) -> List[str]:
        """Languages accepted as translation target (``valid_lang_list``)."""
        target_langs = self.valid_lang_list
        return target_langs

    def support_language(self, lang: str) -> bool:
        """Tell whether ``lang`` is a key of ``lang_map`` with a non-empty value."""
        return lang in self.lang_map and bool(self.lang_map[lang])
    @property
    def supported_src_list(self) -> List[str]:
        """Languages accepted as translation source (``valid_lang_list``)."""
        source_langs = self.valid_lang_list
        return source_langs


@register_translator('google')
@@ -343,9 +351,10 @@ class DeeplTranslator(TranslatorBase):
        result = translator.translate_text(text, source_lang=source, target_lang=target)
        return [i.text for i in result]
    


# Directory holding the Sugoi translation model.
SUGOIMODEL_TRANSLATOR_DIRPATH = 'data/models/sugoi_translator'
# Joined with "/" like the directory path itself: a hard-coded backslash is
# only a path separator on Windows, whereas "/" is accepted on every platform.
SUGOIMODEL_TOKENIZATOR_PATH = SUGOIMODEL_TRANSLATOR_DIRPATH + '/spm.ja.nopretok.model'

@register_translator('Sugoi')
class SugoiTranslator(TranslatorBase):

@@ -366,10 +375,18 @@ class SugoiTranslator(TranslatorBase):
        self.tokenizator = spm.SentencePieceProcessor(model_file=SUGOIMODEL_TOKENIZATOR_PATH)

    def _translate(self, text: Union[str, List]) -> Union[str, List]:
        """Translate one string or a list of strings.

        Args:
            text: a single string, or a list of strings.

        Returns:
            A single translated string when ``text`` was a string, otherwise
            a list with one translated string per input item.
        """
        input_is_lst = not isinstance(text, str)
        sources = text if input_is_lst else [text]

        # Periods are masked with "@" before tokenization and restored after
        # translation.
        # NOTE(review): the second pattern had degraded to an empty string,
        # which makes str.replace insert "@" between every character. It is
        # restored here as the full-width full stop U+FF0E -- confirm against
        # the upstream Sugoi preprocessing.
        masked = [s.replace(".", "@").replace("\uFF0E", "@") for s in sources]
        tokenized_text = self.tokenizator.encode(masked, out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
        tokenized_translated = self.translator.translate_batch(tokenized_text)
        # U+2581 is the SentencePiece whitespace meta symbol; turn it back into
        # a plain space when gluing the tokens of the first hypothesis together.
        text_translated = [
            ''.join(hypotheses[0]["tokens"]).replace("\u2581", ' ').replace("@", ".")
            for hypotheses in tokenized_translated
        ]

        if not input_is_lst:
            return text_translated[0]
        return text_translated

    def updateParam(self, param_key: str, param_content):
@@ -379,6 +396,15 @@ class SugoiTranslator(TranslatorBase):
                delattr(self, 'translator')
            self.translator = ctranslate2.Translator(SUGOIMODEL_TRANSLATOR_DIRPATH, device=self.setup_params['device']['select'])

    @property
    def supported_tgt_list(self) -> List[str]:
        """Target languages offered by this translator: English only."""
        target_langs = ['English']
        return target_langs

    @property
    def supported_src_list(self) -> List[str]:
        """Source languages offered by this translator: Japanese only."""
        source_langs = ['日本語']
        return source_langs


# # "dummy translator" is the name showed in the app
# @register_translator('dummy translator')
# class DummyTranslator(TranslatorBase):
+4 −11
Original line number Diff line number Diff line
@@ -158,18 +158,12 @@ class TranslateThread(ModuleThread):
            setup_params = self.dl_config.translator_setup_params[translator]
            translator_module: TranslatorBase = TRANSLATORS.module_dict[translator]
            if setup_params is not None:
                self.translator = translator_module(source, target, **setup_params)
                self.translator = translator_module(source, target, raise_unsupported_lang=False, **setup_params)
            else:
                self.translator = translator_module(source, target)
                self.translator = translator_module(source, target, raise_unsupported_lang=False)
            self.dl_config.translate_source = self.translator.lang_source
            self.dl_config.translate_target = self.translator.lang_target
            self.dl_config.translator = self.translator.name
        except InvalidSourceOrTargetLanguage as e:
            msg = self.tr('The selected language is not supported by ') + translator + '.\n'
            msg += self.tr('support list: ') + '\n'
            msg += e.message
            self.translator = old_translator
            self.exception_occurred.emit(msg, '', traceback.format_exc())
        except Exception as e:
            self.translator = old_translator
            msg = self.tr('Failed to set translator ') + translator
@@ -253,7 +247,6 @@ class ImgtransThread(QThread):
    exception_occurred = Signal(str, str)
    def __init__(self, 
                 dl_config: DLModuleConfig, 
                 imgtrans_proj: ProjImgTrans, 
                 textdetect_thread: TextDetectThread,
                 ocr_thread: OCRThread,
                 translate_thread: TranslateThread,
@@ -261,12 +254,12 @@ class ImgtransThread(QThread):
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dl_config = dl_config
        self.imgtrans_proj: ProjImgTrans = None
        self.textdetect_thread = textdetect_thread
        self.ocr_thread = ocr_thread
        self.translate_thread = translate_thread
        self.inpaint_thread = inpaint_thread
        self.job = None
        self.imgtrans_proj: ProjImgTrans = None
        self.translate_mode = 1

    @property
@@ -425,7 +418,7 @@ class DLManager(QObject):

        self.progress_msgbox = ProgressMessageBox()

        self.imgtrans_thread = ImgtransThread(dl_config, imgtrans_proj, self.textdetect_thread, self.ocr_thread, self.translate_thread, self.inpaint_thread)
        self.imgtrans_thread = ImgtransThread(dl_config, self.textdetect_thread, self.ocr_thread, self.translate_thread, self.inpaint_thread)
        self.imgtrans_thread.update_detect_progress.connect(self.on_update_detect_progress)
        self.imgtrans_thread.update_ocr_progress.connect(self.on_update_ocr_progress)
        self.imgtrans_thread.update_translate_progress.connect(self.on_update_translate_progress)
+2 −4
Original line number Diff line number Diff line
@@ -216,10 +216,8 @@ class TranslatorConfigPanel(ModuleConfigParseWidget):
        self.source_combobox.clear()
        self.target_combobox.clear()

        for lang in translator.lang_map:
            if translator.lang_map[lang] != '':
                self.source_combobox.addItem(lang)
                self.target_combobox.addItem(lang)
        self.source_combobox.addItems(translator.supported_src_list)
        self.target_combobox.addItems(translator.supported_tgt_list)
        self.translator_combobox.setCurrentText(translator.name)
        self.source_combobox.setCurrentText(translator.lang_source)
        self.target_combobox.setCurrentText(translator.lang_target)