Unverified Commit 7c71844b authored by PiDanShouRouZhouXD's avatar PiDanShouRouZhouXD Committed by GitHub
Browse files

Merge pull request #466 from PiDanShouRouZhouXD/dev

update stariver to fix #461 #456
parents 99e663f0 23e0f845
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -168,16 +168,17 @@ Sugoi 翻译器作者: [mingshiba](https://www.patreon.com/mingshiba)
  
### 文本检测
 * 暂时仅支持日文(方块字都差不多)和英文检测,训练代码和说明见https://github.com/dmMaze/comic-text-detector
 * 支持使用 [星河云(团子漫画OCR)](https://cloud.stariver.org.cn/)的字体检测,需要获取并填写token
   * 参数设置、token获取方式详[团子OCR说明](doc/团子OCR说明.md)
 * 支持使用 [星河云(团子漫画OCR)](https://cloud.stariver.org.cn/)的字体检测,需要填写用户名和密码,每次启动时会自动登录。
   * 详细说明[团子OCR说明](doc/团子OCR说明.md)


### OCR
 * 所有 mit 模型来自 manga-image-translator,支持日英汉识别和颜色提取
 * [manga_ocr](https://github.com/kha-white/manga-ocr) 来自 [kha-white](https://github.com/kha-white),支持日语识别,注意选用该模型程序不会提取颜色
 * 支持使用 [星河云(团子漫画OCR)](https://cloud.stariver.org.cn/)的OCR,需要获取并填写token
   * 参数设置、token获取方式详见 [团子OCR说明](doc/团子OCR说明.md)
 * 支持使用 [星河云(团子漫画OCR)](https://cloud.stariver.org.cn/)的OCR,需要填写用户名和密码,每次启动时会自动登录。
   * 目前的实现方案是逐个textblock进行OCR,速度较慢,准确度没有明显提升,不推荐使用。如果有需要,请使用团子Detector。
   * 推荐文本检测设置为团子Detector时,将OCR设为none_ocr,直接读取文本,节省时间和请求次数。
   * 详细说明见 [团子OCR说明](doc/团子OCR说明.md)


### 图像修复
+6 −5
Original line number Diff line number Diff line
@@ -204,15 +204,16 @@ This project is heavily dependent upon [manga-image-translator](https://github.c
  
## Text detection
 * Support English and Japanese text detection, training code and more details can be found at [comic-text-detector](https://github.com/dmMaze/comic-text-detector)
* Support using text detection from [Stariver Cloud (Tuanzi Comics OCR)](https://cloud.stariver.org.cn/), requires obtaining and filling in the token
   * For parameter settings and how to obtain the token, refer to [Tuanzi OCR Instructions (Chinese only)](doc/团子OCR说明.md)
* Support using text detection from [Starriver Cloud (Tuanzi Manga OCR)](https://cloud.stariver.org.cn/). Username and password need to be filled in, and automatic login will be performed each time the program is launched.

   * For detailed instructions, see [Tuanzi OCR Instructions (Chinese only)](doc/Tuanzi_OCR_Instructions.md)
## OCR
 * All mit* models are from manga-image-translator, support English, Japanese and Korean recognition and text color extraction.
 * [manga_ocr](https://github.com/kha-white/manga-ocr) is from [kha-white](https://github.com/kha-white), text recognition for Japanese, with the main focus being Japanese manga.
* Support using OCR from [Stariver Cloud (Tuanzi Comics OCR)](https://cloud.stariver.org.cn/), requires obtaining and filling in the token
   * For parameter settings and how to obtain the token, refer to [Tuanzi OCR Instructions (Chinese only)](doc/团子OCR说明.md)
   * When setting the text detection to Tuanzi Detector, it is recommended to set OCR to none_ocr, directly read the text, saving time and number of requests.
 * Support using OCR from [Starriver Cloud (Tuanzi Manga OCR)](https://cloud.stariver.org.cn/). Username and password need to be filled in, and automatic login will be performed each time the program is launched.
   * The current implementation uses OCR on each textblock individually, resulting in slower speed and no significant improvement in accuracy. It is not recommended. If needed, please use the Tuanzi Detector instead.
   * When using the Tuanzi Detector for text detection, it is recommended to set OCR to none_ocr to directly read the text, saving time and reducing the number of requests.
   * For detailed instructions, see [Tuanzi OCR Instructions (Chinese only)](doc/Tuanzi_OCR_Instructions.md)

## Inpainting
  * AOT is from [manga-image-translator](https://github.com/zyddnys/manga-image-translator).
+7 −31
Original line number Diff line number Diff line
@@ -4,36 +4,11 @@

</p>

## Token 获取方法
## 团子OCR说明

### 方法1:从cookies中获取token
### 登录
第一次登录时可能会提示密码出错等问题,可以在确认正确输入后勾选并取消勾选`force_refresh_token`选项,以重新登陆。保存后即可正常使用。

在浏览器中登录并访问[星河云OCR](https://cloud.stariver.org.cn/),在浏览器的开发者工具中查看`cookie`,其中包含`token`字段,复制其值。
<p align = "center">
<img src="https://github.com/PiDanShouRouZhouXD/BallonsTranslator/assets/38401147/ae2cbcec-b426-4396-a484-62aa09f22cf6" width="50%" height="50%">

</p>

### 方法2:通过API获取token

通过API获取token的方法如下:

```
POST https://capiv1.ap-sh.starivercs.cn/OCR/Admin/Login


Request Body:
{
    "User": "your_username",
    "Password": "your_password"
}

Response Body:
{

    "Token": "your_token"

}
```

其中,`User``Password`为登录团子OCR的用户名和密码,`Token`为登录成功后返回的token。
### 文本检测
文本检测功能也会提取出文字,而且是整体识别提取。所以当有使用团子的需求时,推荐不要单独使用OCR功能,而是使用团子的文本检测与none_ocr。
团子有自带的拟声词过滤等功能,详细参数设置请参考上方的`官方提供的请求参数参考`
 No newline at end of file
+70 −13
Original line number Diff line number Diff line
@@ -11,8 +11,12 @@ from .base import register_OCR, OCRBase, TextBlock
@register_OCR('stariver_ocr')
class OCRStariver(OCRBase):
    params = {
        'url': 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr',
        'token': 'Replace with your token',
        'User': "填入你的用户名",
        'Password': "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用",
        'force_refresh_token': {
            'type': 'checkbox',
            'value': False
        },
        "refine":{
            'type': 'checkbox',
            'value': True
@@ -40,8 +44,16 @@ class OCRStariver(OCRBase):
    }

    @property
    def token(self):
        return self.params['token']
    def User(self):
        return self.params['User']
    
    @property
    def Password(self):
        return self.params['Password']
    
    @property
    def force_refresh_token(self):
        return self.params['force_refresh_token']['value']
    
    @property
    def expand_ratio(self):
@@ -71,7 +83,32 @@ class OCRStariver(OCRBase):
    def force_expand(self):
        return self.params['force_expand']['value']
    
    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'
        self.debug = False
        self.token = ''
        self.token_obtained = False
        # 初始化时设置用户名和密码为空
        self.register_username = None
        self.register_password = None


    def get_token(self):
        response = requests.post('https://capiv1.ap-sh.starivercs.cn/OCR/Admin/Login', json={
            "User": self.User,
            "Password": self.Password
        }).json()
        if response.get('Status', -1) != "Success":
            self.logger.error(f'stariver ocr 登录失败,错误信息:{response.get("ErrorMsg", "")}')
        token = response.get('Token', '')
        if token != '':
            self.logger.info(f'登录成功,token前10位:{token[:10]}')

        return token

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        self.update_token_if_needed() # 在向服务器发送请求前尝试更新 Token
        im_h, im_w = img.shape[:2]
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
@@ -83,16 +120,14 @@ class OCRStariver(OCRBase):
                blk.text = ['']

    def ocr_img(self, img: np.ndarray) -> str:
        self.update_token_if_needed() # 在向服务器发送请求前尝试更新 Token
        self.logger.debug(f'ocr_img: {img.shape}')
        return self.ocr(img)

    def ocr(self, img: np.ndarray) -> str:
        
        if not self.params['token'] or self.params['token'] == 'Replace with your token':
            raise ValueError('token 没有设置。')
        
        payload = {
            "token": self.params['token'],
            "token": self.token,
            "mask": False,
            "refine": self.refine,
            "filtrate": self.filtrate,
@@ -110,19 +145,19 @@ class OCRStariver(OCRBase):
        response = requests.post(self.url, data=json.dumps(payload))

        if response.status_code != 200:
            print(f'请求失败,状态码:{response.status_code}')
            print(f'stariver ocr 请求失败,状态码:{response.status_code}')
            if response.json().get('Code', -1) != 0:
                print(f'错误信息:{response.json().get("Message", "")}')
                print(f'stariver ocr 错误信息:{response.json().get("Message", "")}')
                with open('stariver_ocr_error.txt', 'w', encoding='utf-8') as f:
                    f.write(response.text)
            raise ValueError('请求失败。')
            raise ValueError('stariver ocr 请求失败。')

        response_data = response.json()['Data']

        if self.debug:
            id = response.json().get('RequestID', '')
            file_name = f"stariver_ocr_response_{id}.json"
            print(f"请求成功,响应数据已保存至{file_name}")
            print(f"stariver ocr 请求成功,响应数据已保存至{file_name}")
            with open(file_name, 'w', encoding='utf-8') as f:
                json.dump(response_data, f, ensure_ascii=False, indent=4)

@@ -130,3 +165,25 @@ class OCRStariver(OCRBase):
                      for block in response_data.get('text_block', [])]
        texts_str = "".join(texts_list).replace('<skip>', '')
        return texts_str

    def update_token_if_needed(self):
        if (self.User != self.register_username or 
            self.Password != self.register_password):
            if self.token_obtained == False:
                if "填入你的用户名" not in self.User and "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用" not in self.Password:
                    if len(self.Password) > 7 and len(self.User) >= 1:
                        new_token = self.get_token()
                        if new_token:  # 确保新获取到有效token再更新信息
                            self.token = new_token
                            self.register_username = self.User
                            self.register_password = self.Password
                            self.token_obtained = True
                            self.logger.info("Token updated due to credential change.")

    def updateParam(self, param_key: str, param_content):
        super().updateParam(param_key, param_content)
        if param_key == 'force_refresh_token':
            self.token_obtained = False  # 强制刷新token时,将标志位设置为False
            self.token = ''  # 强制刷新token时,将token置空
            self.register_username = None  # 强制刷新token时,将用户名置空
            self.register_password = None  # 强制刷新token时,将密码置空
 No newline at end of file
+130 −27
Original line number Diff line number Diff line
@@ -11,7 +11,12 @@ from .base import register_textdetectors, TextDetectorBase, TextBlock
class StariverDetector(TextDetectorBase):

    params = {
        'token': "Replace with your token",
        'User': "填入你的用户名",
        'Password': "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用",
        'force_refresh_token': {
            'type': 'checkbox',
            'value': False
        },
        'expand_ratio': "0.01",
        "refine": {
            'type': 'checkbox',
@@ -42,8 +47,16 @@ class StariverDetector(TextDetectorBase):
    }

    @property
    def token(self):
        return self.params['token']
    def User(self):
        return self.params['User']

    @property
    def Password(self):
        return self.params['Password']

    @property
    def force_refresh_token(self):
        return self.params['force_refresh_token']['value']

    @property
    def expand_ratio(self):
@@ -93,7 +106,25 @@ class StariverDetector(TextDetectorBase):
        super().__init__(**params)
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'
        self.debug = False
        # self.name = 'StariverDetector'
        self.token = ''
        self.token_obtained = False
        # 初始化时设置用户名和密码为空
        self.register_username = None
        self.register_password = None

    def get_token(self):
        response = requests.post('https://capiv1.ap-sh.starivercs.cn/OCR/Admin/Login', json={
            "User": self.User,
            "Password": self.Password
        }).json()
        if response.get('Status', -1) != "Success":
            self.logger.error(
                f'stariver detector 登录失败,错误信息:{response.get("ErrorMsg", "")}')
        token = response.get('Token', '')
        if token != '':
            self.logger.info(f'stariver detector 登录成功,token前10位:{token[:10]}')

        return token

    def adjust_font_size(self, original_font_size):
        new_font_size = original_font_size + self.font_size_offset
@@ -104,10 +135,37 @@ class StariverDetector(TextDetectorBase):
        return new_font_size

    def detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
        if not self.token or self.token == 'Replace with your token':
            self.logger.error(f'token 没有设置。当前token:{self.token}')
            raise ValueError('token 没有设置。')
        img_encoded = cv2.imencode('.jpg', img)[1]
        self.update_token_if_needed() # 在向服务器发送请求前尝试更新 Token
        if not self.token or self.token == '':
            self.logger.error(
                f'stariver detector token 没有设置。当前token:{self.token}')
            raise ValueError('stariver detector token 没有设置。')
        if self.low_accuracy_mode:
            self.logger.info('stariver detector 当前处于低精度模式。')
            short_side = 768
        else:
            short_side = 1536

        # 计算缩放比例
        height, width = img.shape[:2]
        scale = short_side / min(height, width)

        # 计算新的宽高
        new_width = int(width * scale)
        new_height = int(height * scale)

        # 按比例缩放图像
        if scale < 1:
            img_scaled = cv2.resize(
                img, (new_width, new_height), interpolation=cv2.INTER_AREA)
        else:
            img_scaled = img

        # 记录日志
        self.logger.debug(f'图像缩放比例:{scale},图像尺寸:{new_width}x{new_height}')

        # 编码图像为base64
        img_encoded = cv2.imencode('.jpg', img_scaled)[1]
        img_base64 = base64.b64encode(img_encoded).decode('utf-8')

        payload = {
@@ -124,19 +182,31 @@ class StariverDetector(TextDetectorBase):
        }
        if self.debug:
            payload_log = {k: v for k, v in payload.items() if k != 'image'}
            self.logger.debug(f'请求参数:{payload_log}')
            self.logger.debug(f'stariver detector 请求参数:{payload_log}')
        response = requests.post(self.url, json=payload)
        if response.status_code != 200:
            self.logger.error(f'请求失败,状态码:{response.status_code}')
            self.logger.error(
                f'stariver detector 请求失败,状态码:{response.status_code}')
            if response.json().get('Code', -1) != 0:
                self.logger.error(f'错误信息:{response.json().get("Message", "")}')
                self.logger.error(
                    f'stariver detector 错误信息:{response.json().get("Message", "")}')
                with open('stariver_ocr_error.txt', 'w', encoding='utf-8') as f:
                    f.write(response.text)
            raise ValueError('请求失败。')
            raise ValueError('stariver detector 请求失败。')
        response_data = response.json()['Data']

        blk_list = []
        for block in response_data.get('text_block', []):
            if scale < 1:
                xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values()) / scale),
                        int(min(
                            coord[1] for coord in block['block_coordinate'].values()) / scale),
                        int(max(
                            coord[0] for coord in block['block_coordinate'].values()) / scale),
                        int(max(coord[1] for coord in block['block_coordinate'].values()) / scale)]
                lines = [np.array([[coord[pos][0] / scale, coord[pos][1] / scale] for pos in ['upper_left', 'upper_right',
                                                                                              'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
            else:
                xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values())),
                        int(min(coord[1]
                            for coord in block['block_coordinate'].values())),
@@ -145,14 +215,19 @@ class StariverDetector(TextDetectorBase):
                        int(max(coord[1] for coord in block['block_coordinate'].values()))]
                lines = [np.array([[coord[pos][0], coord[pos][1]] for pos in ['upper_left', 'upper_right',
                                                                              'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
            texts = [text.replace('<skip>', '') for text in block.get('texts', [])]
            texts = [text.replace('<skip>', '')
                     for text in block.get('texts', [])]

            original_font_size = block.get('text_size', 0)

            font_size_recalculated = self.adjust_font_size(original_font_size)
            scaled_font_size = original_font_size / \
                scale if scale < 1 else original_font_size

            font_size_recalculated = self.adjust_font_size(scaled_font_size)

            if self.debug:
                self.logger.debug(f'原始字体大小:{original_font_size},修正后字体大小:{font_size_recalculated}')
                self.logger.debug(
                    f'原始字体大小:{original_font_size},修正后字体大小:{font_size_recalculated}')

            blk = TextBlock(
                xyxy=xyxy,
@@ -173,6 +248,12 @@ class StariverDetector(TextDetectorBase):

        mask = self._decode_base64_mask(response_data['mask'])
        mask = self.expand_mask(mask)

        # scale back to original size
        if scale < 1:
            mask = cv2.resize(mask, (width, height),
                              interpolation=cv2.INTER_NEAREST)
        self.logger.debug(f'检测结果mask尺寸:{mask.shape}')
        return mask, blk_list

    @staticmethod
@@ -214,3 +295,25 @@ class StariverDetector(TextDetectorBase):
        dilated_mask = (dilated_mask > 0).astype(np.uint8) * 255

        return dilated_mask

    def update_token_if_needed(self):
        if (self.User != self.register_username or 
            self.Password != self.register_password):
            if self.token_obtained == False:
                if "填入你的用户名" not in self.User and "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用" not in self.Password:
                    if len(self.Password) > 7 and len(self.User) >= 1:
                        new_token = self.get_token()
                        if new_token:  # 确保新获取到有效token再更新信息
                            self.token = new_token
                            self.register_username = self.User
                            self.register_password = self.Password
                            self.token_obtained = True
                            self.logger.info("Token updated due to credential change.")

    def updateParam(self, param_key: str, param_content):
        super().updateParam(param_key, param_content)
        if param_key == 'force_refresh_token':
            self.token_obtained = False  # 强制刷新token时,将标志位设置为False
            self.token = ''  # 强制刷新token时,将token置空
            self.register_username = None  # 强制刷新token时,将用户名置空
            self.register_password = None  # 强制刷新token时,将密码置空