Unverified Commit 0e05f875 authored by PiDanShouRouZhouXD's avatar PiDanShouRouZhouXD Committed by GitHub
Browse files

Update trans_sakura.py (#614)

1. Adapt to Sakura 1.0 prompt and modify dictionary logic.
2. Fix the error when the dictionary path is empty.
parent 9332d2dc
Loading
Loading
Loading
Loading
+28 −32
Original line number Diff line number Diff line
@@ -64,12 +64,17 @@ class SakuraDict():
        self.logger = logger
        self.dict_str = ""
        self.version = version
        self.path = path

        if not path:
            return  # 如果路径为空,直接返回,不加载字典

        if not os.path.exists(path):
            if self.version != "0.9":
                self.logger.warning(f"字典文件不存在: {path}")
                self.logger.info(f"字典文件不存在: {path}\n 如果您不需要字典功能,请忽略此警告。")
            return
        self.path = path
        if self.version == "0.10":

        if self.version == "1.0":
            try:
                self.load_dict(path)
            except Exception as e:
@@ -89,7 +94,7 @@ class SakuraDict():
            字典文件路径

        """
        if self.version == "0.9":
        if self.version == "0.9" or not dic_path:
            return

        dic_type = self._detect_type(dic_path)
@@ -245,7 +250,7 @@ class SakuraDict():
            字典内容字符串

        """
        if self.version == "0.9":
        if self.version == "0.9" or not self.path:
            return ""

        if not self.dict_str:
@@ -270,7 +275,7 @@ class SakuraDict():
            字典内容字符串

        """
        if self.version == "0.9":
        if self.version == "0.9" or not self.path:
            return ""

        if not self.dict_str:
@@ -304,7 +309,7 @@ class SakuraDict():
            字典内容的JSON格式字符串

        """
        if self.version == "0.9":
        if self.version == "0.9" or not self.path:
            return ""

        if not self.dict_str:
@@ -338,7 +343,7 @@ class SakuraDict():
            字典类型,可选值有"sakura""galtransl""json",默认为"sakura"

        """
        if self.version == "0.9":
        if self.version == "0.9" or not self.path:
            return

        if dict_type == "sakura":
@@ -371,7 +376,7 @@ class SakuraTranslator(BaseTranslator):
            'type': 'selector',
            'options': [
                '0.9',
                '0.10',
                '1.0',
                'galtransl-v1'
            ],
            'value': '0.9'
@@ -385,8 +390,8 @@ class SakuraTranslator(BaseTranslator):
    _CHAT_SYSTEM_TEMPLATE_009 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。'
    )
    _CHAT_SYSTEM_TEMPLATE_010 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行'
    _CHAT_SYSTEM_TEMPLATE_100 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,擅自添加原文中没有的代词。'
    )

    _CHAT_SYSTEM_TEMPLATE_GALTRANSL_V1 = (
@@ -409,14 +414,6 @@ class SakuraTranslator(BaseTranslator):
    def max_tokens(self) -> int:
        return self.params['max tokens']

    @property
    def timeout(self) -> int:
        return self.params['timeout']

    @property
    def retry_attempts(self) -> int:
        return self.params['retry attempts']

    @property
    def api_base_raw(self) -> str:
        return self.params['api baseurl']
@@ -447,7 +444,7 @@ class SakuraTranslator(BaseTranslator):
        self._current_style = "precise"
        self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]')
        self._heart_pattern = re.compile(r'')
        sakura_version = self.sakura_version if self.sakura_version!= 'galtransl-v1' else '0.10'
        sakura_version = self.sakura_version if self.sakura_version!= 'galtransl-v1' else '1.0'
        self.sakura_dict = SakuraDict(
            self.dict_path, self.logger, sakura_version)
        self.logger.info(f'当前选择的Sakura版本: {self.sakura_version}')
@@ -563,13 +560,13 @@ class SakuraTranslator(BaseTranslator):
            '将下面的日文文本翻译成中文:',
            prompt,
        ])
        prompt_010 = '\n'.join([
        prompt_100 = '\n'.join([
            'System:',
            self._CHAT_SYSTEM_TEMPLATE_010,
            self._CHAT_SYSTEM_TEMPLATE_100,
            'User:',
            "根据以下术语表:",
            "根据以下术语表(可以为空)",
            gpt_dict_raw_text,
            "将下面的日文文本根据上述术语表的对应关系和注翻译成中文:",
            "将下面的日文文本根据对应关系和注翻译成中文:",
            prompt,
        ])
        prompt_galtransl_v1 = '\n'.join([
@@ -583,8 +580,8 @@ class SakuraTranslator(BaseTranslator):
        ])
        if self.sakura_version == '0.9':
            return prompt_009
        elif self.sakura_version == '0.10':
            return prompt_010
        elif self.sakura_version == '1.0':
            return prompt_100
        else:
            return prompt_galtransl_v1

@@ -791,7 +788,8 @@ class SakuraTranslator(BaseTranslator):
            'num_beams': 1,
            'repetition_penalty': 1.0,
        }
        if self.sakura_version == "0.9":
        gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
        if self.sakura_version == "0.9" or gpt_dict_raw_text == "":
            messages = [
                {
                    "role": "system",
@@ -802,22 +800,20 @@ class SakuraTranslator(BaseTranslator):
                    "content": f"将下面的日文文本翻译成中文:{raw_text}"
                }
            ]
        elif self.sakura_version == "0.10":
        elif self.sakura_version == "1.0":
            gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
            messages = [
                {
                    "role": "system",
                    "content": f"{self._CHAT_SYSTEM_TEMPLATE_010}"
                    "content": f"{self._CHAT_SYSTEM_TEMPLATE_100}"
                },
                {
                    "role": "user",
                    "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注翻译成中文:{raw_text}"
                    "content": f"根据以下术语表(可以为空)\n{gpt_dict_raw_text}\n将下面的日文文本根据对应关系和注翻译成中文:{raw_text}"
                }
            ]
        else:
            gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
            messages = [
                {
                    "role": "system",