Unverified Commit a4b1b4d0 authored by PiDanShouRouZhouXD's avatar PiDanShouRouZhouXD Committed by GitHub
Browse files

Merge pull request #486 from PiDanShouRouZhouXD/dev

Update Sakura translator
parents 046bc0e1 e8afab20
Loading
Loading
Loading
Loading
+306 −145
Original line number Diff line number Diff line
# 同步更新自manga-image-translator

from http import client
import logging
import re
import time
from token import OP
from typing import List, Dict, Union, Callable
import time
import os
import json

import openai

@@ -21,189 +20,310 @@ class InvalidNumTranslations(Exception):


class SakuraDict():
    """
    Sakura字典类,用于加载和管理Sakura字典。

    属性:
    --------
    logger : logging.Logger
        日志记录器对象
    dict_str : str
        字典内容字符串
    version : str
        Sakura字典版本号
    path : str
        字典文件路径

    方法:
    --------
    __init__(self, path: str, logger: logging.Logger, version: str = "0.9") -> None:
        初始化Sakura字典对象。
    load_dict(self, dic_path: str) -> None:
        根据字典类型加载字典。
    get_dict_str(self) -> str:
        获取字典内容字符串。
    save_dict_to_file(self, dic_path: str, dict_type: str = "sakura") -> None:
        将字典内容保存到文件。

    """

    def __init__(self, path: str, logger: logging.Logger, version: str = "0.9") -> None:
        """
        初始化Sakura字典对象。

        参数:
        --------
        path : str
            字典文件路径
        logger : logging.Logger
            日志记录器对象
        version : str, optional
            Sakura字典版本号,默认为"0.9"

        """
        self.logger = logger
        self.dict_str = ""
        self.version = version
        if not os.path.exists(path):
            if self.version == '0.10':
            self.logger.warning(f"字典文件不存在: {path}")
            return
        else:
        self.path = path
        if self.version == '0.10':
            self.dict_str = self.get_dict_from_file(path)
        if self.version == '0.9':
        if self.version == "0.10":
            try:
                self.load_dict(path)
            except Exception as e:
                self.logger.warning(f"载入字典失败: {e}")
        elif self.version == "0.9":
            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")

    def load_galtransl_dic(self, dic_path: str):
    def load_dict(self, dic_path: str) -> None:
        """
        载入Galtransl词典。
        根据字典类型加载字典。

        参数:
        --------
        dic_path : str
            字典文件路径

        """
        dic_type = self._detect_type(dic_path)
        if dic_type == "galtransl":
            self._load_galtransl_dic(dic_path)
        elif dic_type == "sakura":
            self._load_sakura_dict(dic_path)
        elif dic_type == "json":
            self._load_json_dict(dic_path)
        else:
            self.logger.warning(f"未知的字典类型: {dic_path}")

    def _load_galtransl_dic(self, dic_path: str) -> None:
        """
        加载Galtransl格式的字典。

        参数:
        --------
        dic_path : str
            字典文件路径

        """
        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()
        if len(dic_lines) == 0:
        if not dic_lines:
            return
        dic_path = os.path.abspath(dic_path)
        dic_name = os.path.basename(dic_path)
        normalDic_count = 0

        gpt_dict = []
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):  # 注释行跳过
            if line.startswith(("\n", "\\\\", "//")):
                continue

            # 四个空格换成Tab
            line = line.replace("    ", "\t")

            sp = line.rstrip("\r\n").split("\t")  # 去多余换行符,Tab分割
            len_sp = len(sp)

            if len_sp < 2:  # 至少是2个元素
            sp = line.rstrip("\r\n").split("\t")
            if len(sp) < 2:
                continue
            src, dst, *info = sp
            gpt_dict.append(
                {"src": src, "dst": dst, "info": info[0] if info else None})
        gpt_dict_text_list = [
            f"{gpt['src']}->{gpt['dst']}{' #' + gpt['info'] if gpt['info'] else ''}" for gpt in gpt_dict]
        self.dict_str = "\n".join(gpt_dict_text_list)
        self.logger.info(f"载入 Galtransl 字典: {dic_name} {len(gpt_dict)}普通词条")

    def _load_sakura_dict(self, dic_path: str) -> None:
        """
        加载Sakura格式的字典。

            src = sp[0]
            dst = sp[1]
            info = sp[2] if len_sp > 2 else None
            gpt_dict.append({"src": src, "dst": dst, "info": info})
            normalDic_count += 1

        gpt_dict_text_list = []
        for gpt in gpt_dict:
            src = gpt['src']
            dst = gpt['dst']
            info = gpt['info'] if "info" in gpt.keys() else None
            if info:
                single = f"{src}->{dst} #{info}"
            else:
                single = f"{src}->{dst}"
            gpt_dict_text_list.append(single)

        gpt_dict_raw_text = "\n".join(gpt_dict_text_list)
        self.dict_str = gpt_dict_raw_text
        self.logger.info(
            f"载入 Galtransl 字典: {dic_name} {normalDic_count}普通词条"
        )
        参数:
        --------
        dic_path : str
            字典文件路径

    def load_sakura_dict(self, dic_path: str):
        """
        直接载入标准的Sakura字典。
        """

        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()

        if len(dic_lines) == 0:
        if not dic_lines:
            return
        dic_path = os.path.abspath(dic_path)
        dic_name = os.path.basename(dic_path)
        normalDic_count = 0

        gpt_dict_text_list = []
        for line in dic_lines:
            if line.startswith("\n"):
            if line.startswith(("\n", "\\\\", "//")):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):  # 注释行跳过
            sp = line.rstrip("\r\n").split("->")
            if len(sp) < 2:
                continue
            src, dst_info = sp
            dst_info_sp = dst_info.split("#")
            dst = dst_info_sp[0].strip()
            info = dst_info_sp[1].strip() if len(dst_info_sp) > 1 else None
            gpt_dict_text_list.append(
                f"{src}->{dst}{' #' + info if info else ''}")
        self.dict_str = "\n".join(gpt_dict_text_list)
        self.logger.info(
            f"载入标准Sakura字典: {dic_name} {len(gpt_dict_text_list)}普通词条")

            sp = line.rstrip("\r\n").split("->")  # 去多余换行符,->分割
            len_sp = len(sp)
    def _load_json_dict(self, dic_path: str) -> None:
        """
        加载JSON格式的字典。

            if len_sp < 2:  # 至少是2个元素
        参数:
        --------
        dic_path : str
            字典文件路径

        """
        with open(dic_path, encoding="utf8") as f:
            dic_json = json.load(f)
        if not dic_json:
            return
        dic_name = os.path.basename(dic_path)
        gpt_dict_text_list = []
        for item in dic_json:
            if not item:
                continue
            src = item.get("src", "")
            dst = item.get("dst", "")
            info = item.get("info", "")
            gpt_dict_text_list.append(
                f"{src}->{dst}{' #' + info if info else ''}")
        self.dict_str = "\n".join(gpt_dict_text_list)
        self.logger.info(f"载入JSON字典: {dic_name} {len(gpt_dict_text_list)}条记录")

    def _detect_type(self, dic_path: str) -> str:
        """
        检测字典文件的类型。

            src = sp[0]
            dst_info = sp[1].split("#")  # 使用#分割目标和信息
            dst = dst_info[0].strip()
            info = dst_info[1].strip() if len(dst_info) > 1 else None
            if info:
                single = f"{src}->{dst} #{info}"
            else:
                single = f"{src}->{dst}"
            gpt_dict_text_list.append(single)
            normalDic_count += 1
        参数:
        --------
        dic_path : str
            字典文件路径

        gpt_dict_raw_text = "\n".join(gpt_dict_text_list)
        self.dict_str = gpt_dict_raw_text
        self.logger.info(
            f"载入标准Sakura字典: {dic_name} {normalDic_count}普通词条"
        )
        返回:
        --------
        str
            字典类型,可能的值有"galtransl""sakura""json""unknown"

    def detect_type(self, dic_path: str):
        """
        检测字典类型。
        """
        with open(dic_path, encoding="utf8") as f:
            dic_lines = f.readlines()
        self.logger.debug(f"检测字典类型: {dic_path}")
        if len(dic_lines) == 0:
        if not dic_lines:
            return "unknown"

        # 判断是否为Galtransl字典
        is_galtransl = True
        if dic_path.endswith(".json"):
            return "json"
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):
            if line.startswith(("\n", "\\\\", "//")):
                continue
            if "\t" in line or "    " in line:
                return "galtransl"
            elif "->" in line:
                return "sakura"
        return "unknown"

            if "\t" not in line and "    " not in line:
                is_galtransl = False
                break
    def get_dict_str(self) -> str:
        """
        获取字典内容字符串。

        if is_galtransl:
            return "galtransl"
        返回:
        --------
        str
            字典内容字符串

        # 判断是否为Sakura字典
        is_sakura = True
        for line in dic_lines:
            if line.startswith("\n"):
                continue
            elif line.startswith("\\\\") or line.startswith("//"):
                continue
        """
        if not self.dict_str:
            try:
                self.load_dict(self.path)
            except Exception as e:
                self.logger.warning(f"载入字典失败: {e}")
        return self.dict_str
    
            if "->" not in line:
                is_sakura = False
                break
    def get_dict_str_within_text(self, text: str) -> str:
        """
        获取字典内容字符串,仅保留字典中出现的词条。

        if is_sakura:
            return "sakura"
        参数:
        --------
        text : str
            待翻译文本

        return "unknown"
        返回:
        --------
        str
            字典内容字符串

    def get_dict_str(self):
        """
        获取字典内容。
        """
        if self.version == '0.9':
            self.logger.info("您当前选择了Sakura 0.9版本,暂不支持术语表")
            return ""
        if self.dict_str == "":
        if not self.dict_str:
            try:
                self.dict_str = self.get_dict_from_file(self.path)
                return self.dict_str
                self.load_dict(self.path)
            except Exception as e:
                if self.version == '0.10':
                self.logger.warning(f"载入字典失败: {e}")
                return ""
        return self.dict_str

    def get_dict_from_file(self, dic_path: str):
        # 初始化一个空列表用于存储匹配的字典行
        matched_dict_lines = []

        # 遍历字典中的每一行
        for line in self.dict_str.splitlines():
            if '->' in line:
                src = line.split('->')[0]
                # 检查 src 是否在输入文本中
                if src in text:
                    matched_dict_lines.append(line)

        # 将匹配的字典行拼接成一个字符串并返回
        return '\n'.join(matched_dict_lines)

    def dict_to_json(self) -> str:
        """
        从文件载入字典。
        将字典内容转换为JSON格式。

        返回:
        --------
        str
            字典内容的JSON格式字符串

        """
        dic_type = self.detect_type(dic_path)
        if dic_type == "galtransl":
            self.load_galtransl_dic(dic_path)
        elif dic_type == "sakura":
            self.load_sakura_dict(dic_path)
        else:
            self.logger.warning(f"未知的字典类型: {dic_path}")
        return self.get_dict_str()
        if not self.dict_str:
            try:
                self.load_dict(self.path)
            except Exception as e:
                self.logger.warning(f"载入字典失败: {e}")
        dict_json = []
        for line in self.dict_str.split("\n"):
            if not line:
                continue
            sp = line.split("->")
            if len(sp) < 2:
                continue
            src, dst_info = sp
            dst_info_sp = dst_info.split("#")
            dst = dst_info_sp[0].strip()
            info = dst_info_sp[1].strip() if len(dst_info_sp) > 1 else None
            dict_json.append({"src": src, "dst": dst, "info": info})
        return json.dumps(dict_json, ensure_ascii=False, indent=4)

    def save_dict_to_file(self, dic_path: str, dict_type: str = "sakura") -> None:
        """
        将字典内容保存到文件。

        参数:
        --------
        dic_path : str
            字典文件保存路径
        dict_type : str, optional
            字典类型,可选值有"sakura""galtransl""json",默认为"sakura"

        """
        if dict_type == "sakura":
            with open(dic_path, "w", encoding="utf8") as f:
                f.write(self.dict_str)
        elif dict_type == "galtransl":
            with open(dic_path, "w", encoding="utf8") as f:
                f.write(self.dict_str.replace(
                    "->", "    ").replace(" #", "    "))
        elif dict_type == "json":
            json_data = self.dict_to_json()
            with open(dic_path, "w", encoding="utf8") as f:
                json.dump(json_data, f, ensure_ascii=False, indent=4)
        else:
            self.logger.warning(f"未知的字典类型: {dict_type}")

@register_translator('Sakura')
class SakuraTranslator(BaseTranslator):
@@ -216,20 +336,17 @@ class SakuraTranslator(BaseTranslator):
            'type': 'selector',
            'options': [
                '0.9',
                '0.10'
                '0.10',
                'galtransl-v1'
            ],
            'value': '0.9'
        },
        'retry attempts': 3,
        'timeout': 999,
        'max tokens': 1024,
        'repeat detect threshold': 20,
    }

    _TIMEOUT = 999  # 等待服务器响应的超时时间(秒)
    _TIMEOUT_RETRY_ATTEMPTS = 3  # 请求超时时的重试次数
    _RATELIMIT_RETRY_ATTEMPTS = 3  # 请求被限速时的重试次数
    _REPEAT_DETECT_THRESHOLD = 20  # 重复检测的阈值

    _CHAT_SYSTEM_TEMPLATE_009 = (
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。'
    )
@@ -237,6 +354,22 @@ class SakuraTranslator(BaseTranslator):
        '你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。'
    )

    _CHAT_SYSTEM_TEMPLATE_GALTRANSL_V1 = (
        '你是一个视觉小说翻译模型,可以通顺地使用给定的术语表以指定的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要混淆使役态和被动态的主语和宾语,不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。'
    )

    @property 
    def timeout(self) -> int:
        return self.params['timeout']
    
    @property
    def retry_attempts(self) -> int:
        return self.params['retry attempts']
    
    @property
    def repeat_detect_threshold(self) -> int:
        return self.params['repeat detect threshold']

    @property
    def max_tokens(self) -> int:
        return self.params['max tokens']
@@ -279,8 +412,9 @@ class SakuraTranslator(BaseTranslator):
        self._current_style = "precise"
        self._emoji_pattern = re.compile(r'[\U00010000-\U0010ffff]')
        self._heart_pattern = re.compile(r'')
        sakura_version = self.sakura_version if self.sakura_version!= 'galtransl-v1' else '0.10'
        self.sakura_dict = SakuraDict(
            self.dict_path, self.logger, self.sakura_version)
            self.dict_path, self.logger, sakura_version)
        self.logger.info(f'当前选择的Sakura版本: {self.sakura_version}')

    def updateParam(self, param_key: str, param_content):
@@ -386,7 +520,7 @@ class SakuraTranslator(BaseTranslator):
        return repeated, s, longest_count, longest_pattern, actual_threshold

    def _format_prompt_log(self, prompt: str) -> str:
        gpt_dict_raw_text = self.sakura_dict.get_dict_str()
        gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(prompt)
        prompt_009 = '\n'.join([
            'System:',
            self._CHAT_SYSTEM_TEMPLATE_009,
@@ -403,7 +537,21 @@ class SakuraTranslator(BaseTranslator):
            "将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:",
            prompt,
        ])
        return prompt_009 if self.sakura_version == '0.9' else prompt_010
        prompt_galtransl_v1 = '\n'.join([
            'System:',
            self._CHAT_SYSTEM_TEMPLATE_GALTRANSL_V1,
            'User:',
            "根据以下术语表:",
            gpt_dict_raw_text,
            "将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:",
            prompt,
        ])
        if self.sakura_version == '0.9':
            return prompt_009
        elif self.sakura_version == '0.10':
            return prompt_010
        else:
            return prompt_galtransl_v1

    def _split_text(self, text: str) -> List[str]:
        """
@@ -440,13 +588,13 @@ class SakuraTranslator(BaseTranslator):
            return None

        # 检查请求内容是否含有超过默认阈值的重复内容
        if self.detect_and_calculate_repeats(''.join(queries), self._REPEAT_DETECT_THRESHOLD)[0]:
        if self.detect_and_calculate_repeats(''.join(queries), self.repeat_detect_threshold)[0]:
            self.logger.warning(
                f'请求内容本身含有超过默认阈值{self._REPEAT_DETECT_THRESHOLD}的重复内容。')
                f'请求内容本身含有超过默认阈值{self.repeat_detect_threshold}的重复内容。')

        # 根据译文众数和默认阈值计算实际阈值
        actual_threshold = max(max(self.detect_and_calculate_repeats(
            query)[4] for query in queries), self._REPEAT_DETECT_THRESHOLD)
            query)[4] for query in queries), self.repeat_detect_threshold)

        if self.detect_and_calculate_repeats(response, actual_threshold)[0]:
            response = _retry_translation(queries, lambda r: self.detect_and_calculate_repeats(
@@ -532,7 +680,7 @@ class SakuraTranslator(BaseTranslator):
                    break
                except openai.RateLimitError:
                    ratelimit_attempt += 1
                    if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS:
                    if ratelimit_attempt >= self.retry_attempts:
                        raise
                    self.logger.warning(
                        f'Sakura因被限速而进行重试。尝试次数: {ratelimit_attempt}')
@@ -552,7 +700,7 @@ class SakuraTranslator(BaseTranslator):
                    time.sleep(30)
                except TimeoutError:
                    timeout_attempt += 1
                    if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS:
                    if timeout_attempt >= self.retry_attempts:
                        raise Exception('Sakura超时。')
                    self.logger.warning(
                        f'Sakura因超时而进行重试。尝试次数: {timeout_attempt}')
@@ -562,7 +710,7 @@ class SakuraTranslator(BaseTranslator):
                    break
                except openai.error.RateLimitError:
                    ratelimit_attempt += 1
                    if ratelimit_attempt >= self._RATELIMIT_RETRY_ATTEMPTS:
                    if ratelimit_attempt >= self.retry_attempts:
                        raise
                    self.logger.warning(
                        f'Sakura因被限速而进行重试。尝试次数: {ratelimit_attempt}')
@@ -591,7 +739,7 @@ class SakuraTranslator(BaseTranslator):
                    time.sleep(30)
                except TimeoutError:
                    timeout_attempt += 1
                    if timeout_attempt >= self._TIMEOUT_RETRY_ATTEMPTS:
                    if timeout_attempt >= self.retry_attempts:
                        raise Exception('Sakura超时。')
                    self.logger.warning(
                        f'Sakura因超时而进行重试。尝试次数: {timeout_attempt}')
@@ -619,8 +767,8 @@ class SakuraTranslator(BaseTranslator):
                    "content": f"将下面的日文文本翻译成中文:{raw_text}"
                }
            ]
        else:
            gpt_dict_raw_text = self.sakura_dict.get_dict_str()
        elif self.sakura_version == "0.10":
            gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
            messages = [
                {
@@ -632,6 +780,19 @@ class SakuraTranslator(BaseTranslator):
                    "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}"
                }
            ]
        else:
            gpt_dict_raw_text = self.sakura_dict.get_dict_str_within_text(raw_text)
            self.logger.debug(f"Sakura Dict: {gpt_dict_raw_text}")
            messages = [
                {
                    "role": "system",
                    "content": f"{self._CHAT_SYSTEM_TEMPLATE_GALTRANSL_V1}"
                },
                {
                    "role": "user",
                    "content": f"根据以下术语表:\n{gpt_dict_raw_text}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{raw_text}"
                }
            ]
        if OPENAPI_V1_API:
            client = openai.Client(
                api_key="sk-114514",