Commit b39fbfae authored by PiDanShouRouZhouXD
Browse files

Update Sakura translator: add a "force apply dict" option

parent 0e05f875
Loading
Loading
Loading
Loading
+19 −6
Original line number Diff line number Diff line
@@ -107,6 +107,8 @@ class SakuraDict():
        else:
            self.logger.warning(f"未知的字典类型: {dic_path}")

        self.logger.debug(f"字典内容(转换后): {self.dict_str[:100]}")

    def _load_galtransl_dic(self, dic_path: str) -> None:
        """
        加载Galtransl格式的字典。
@@ -260,7 +262,7 @@ class SakuraDict():
                self.logger.warning(f"载入字典失败: {e}")
        return self.dict_str
    
    def get_dict_str_within_text(self, text: str) -> str:
    def get_dict_str_within_text(self, text: str, force_apply_dict: bool = False) -> str:
        """
        获取字典内容字符串,仅保留字典中出现的词条。

@@ -275,6 +277,8 @@ class SakuraDict():
            字典内容字符串

        """
        if force_apply_dict:
            return self.get_dict_str()
        if self.version == "0.9" or not self.path:
            return ""

@@ -293,7 +297,9 @@ class SakuraDict():
            if '->' in line:
                src = line.split('->')[0]
                # 检查 src 是否在输入文本中
                # self.logger.debug(f"检查字典原文{src}是否在文本{text}中")
                if src in text:
                    # self.logger.debug(f"匹配到字典行: {line}")
                    matched_dict_lines.append(line)

        # 将匹配的字典行拼接成一个字符串并返回
@@ -385,6 +391,11 @@ class SakuraTranslator(BaseTranslator):
        'timeout': 999,
        'max tokens': 1024,
        'repeat detect threshold': 20,
        'force apply dict': {
            'value': False,
            'description': 'Force apply the dictionary regardless of whether the terms appear in the original text \n DO NOT CHECK THIS IF YOU ARE NOT SURE WHAT IT MEANS',
            'type': 'checkbox',
        },
    }

    _CHAT_SYSTEM_TEMPLATE_009 = (
@@ -435,6 +446,10 @@ class SakuraTranslator(BaseTranslator):
    def dict_path(self) -> str:
        return self.params['dict path']

    @property
    def force_apply_dict(self) -> bool:
        """Return the current value of the 'force apply dict' parameter.

        The parameter is stored in ``self.params`` as a mapping, and the
        setting itself lives under that mapping's ``'value'`` key.
        """
        option = self.params['force apply dict']
        return option['value']

    def _setup_translator(self):
        self.lang_map['简体中文'] = 'Simplified Chinese'
        self.lang_map['日本語'] = 'Japanese'
@@ -552,7 +567,7 @@ class SakuraTranslator(BaseTranslator):
        return repeated, s, longest_count, longest_pattern, actual_threshold

    def _format_prompt_log(self, prompt: str) -> str:
        gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(prompt)
        gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(prompt, self.force_apply_dict)
        prompt_009 = '\n'.join([
            'System:',
            self._CHAT_SYSTEM_TEMPLATE_009,
@@ -788,7 +803,7 @@ class SakuraTranslator(BaseTranslator):
            'num_beams': 1,
            'repetition_penalty': 1.0,
        }
        gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
        gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text, self.force_apply_dict)
        if self.sakura_version == "0.9" or gpt_dict_raw_text == "":
            messages = [
                {
@@ -801,7 +816,6 @@ class SakuraTranslator(BaseTranslator):
                }
            ]
        elif self.sakura_version == "1.0":
            gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
            messages = [
                {
                    "role": "system",
@@ -813,7 +827,6 @@ class SakuraTranslator(BaseTranslator):
                }
            ]
        else:
            gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
            messages = [
                {
                    "role": "system",