Commit 8af61e35 authored by PiDanShouRouZhouXD's avatar PiDanShouRouZhouXD
Browse files

update stariver

parent 4ceed58a
Loading
Loading
Loading
Loading
+64 −13
Original line number Diff line number Diff line
@@ -11,8 +11,12 @@ from .base import register_OCR, OCRBase, TextBlock
@register_OCR('stariver_ocr')
class OCRStariver(OCRBase):
    params = {
        'url': 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr',
        'token': 'Replace with your token',
        'User': "填入你的用户名",
        'Password': "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用",
        'force_refresh_token': {
            'type': 'checkbox',
            'value': False
        },
        "refine":{
            'type': 'checkbox',
            'value': True
@@ -40,8 +44,16 @@ class OCRStariver(OCRBase):
    }

    @property
    def token(self):
        return self.params['token']
    def User(self):
        return self.params['User']
    
    @property
    def Password(self):
        return self.params['Password']
    
    @property
    def force_refresh_token(self):
        return self.params['force_refresh_token']['value']
    
    @property
    def expand_ratio(self):
@@ -71,6 +83,40 @@ class OCRStariver(OCRBase):
    def force_expand(self):
        return self.params['force_expand']['value']
    
    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.token = ''
        self.url = "https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr"
        self.token_obtained = False  # 添加一个标志位来判断token是否已经获取过
        
        # 在初始化时尝试获取token
        if not self.token_obtained:
            self.update_token_if_needed()
            self.token_obtained = True  # 将标志位设置为True,表示已获取token

    def update_token_if_needed(self):
        if "填入你的用户名" not in self.User and "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用" not in self.Password:
            if not self.token_obtained or self.force_refresh_token:  # 检查标志位,只有在第一次运行时获取token
                if len(self.Password) > 7 and len(self.User) >= 1:
                    self.token = self.get_token()
                    if self.token != '':
                        self.token_obtained = True  # 获取成功后,将标志位设置为True
        else:
            self.logger.warning('stariver ocr 用户名或密码为空,无法更新token。')

    def get_token(self):
        response = requests.post('https://capiv1.ap-sh.starivercs.cn/OCR/Admin/Login', json={
            "User": self.User,
            "Password": self.Password
        }).json()
        if response.get('Status', -1) != "Success":
            self.logger.error(f'stariver ocr 登录失败,错误信息:{response.get("ErrorMsg", "")}')
        token = response.get('Token', '')
        if token != '':
            self.logger.info(f'登录成功,token前10位:{token[:10]}')

        return token

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        im_h, im_w = img.shape[:2]
        for blk in blk_list:
@@ -88,11 +134,8 @@ class OCRStariver(OCRBase):

    def ocr(self, img: np.ndarray) -> str:
        
        if not self.params['token'] or self.params['token'] == 'Replace with your token':
            raise ValueError('token 没有设置。')
        
        payload = {
            "token": self.params['token'],
            "token": self.token,
            "mask": False,
            "refine": self.refine,
            "filtrate": self.filtrate,
@@ -110,19 +153,19 @@ class OCRStariver(OCRBase):
        response = requests.post(self.url, data=json.dumps(payload))

        if response.status_code != 200:
            print(f'请求失败,状态码:{response.status_code}')
            print(f'stariver ocr 请求失败,状态码:{response.status_code}')
            if response.json().get('Code', -1) != 0:
                print(f'错误信息:{response.json().get("Message", "")}')
                print(f'stariver ocr 错误信息:{response.json().get("Message", "")}')
                with open('stariver_ocr_error.txt', 'w', encoding='utf-8') as f:
                    f.write(response.text)
            raise ValueError('请求失败。')
            raise ValueError('stariver ocr 请求失败。')

        response_data = response.json()['Data']

        if self.debug:
            id = response.json().get('RequestID', '')
            file_name = f"stariver_ocr_response_{id}.json"
            print(f"请求成功,响应数据已保存至{file_name}")
            print(f"stariver ocr 请求成功,响应数据已保存至{file_name}")
            with open(file_name, 'w', encoding='utf-8') as f:
                json.dump(response_data, f, ensure_ascii=False, indent=4)

@@ -130,3 +173,11 @@ class OCRStariver(OCRBase):
                      for block in response_data.get('text_block', [])]
        texts_str = "".join(texts_list).replace('<skip>', '')
        return texts_str

    def updateParam(self, param_key: str, param_content):
        super().updateParam(param_key, param_content)
        if param_key == 'User' or param_key == 'Password':
            if not self.token_obtained or self.force_refresh_token:  # 检查标志位,只有在第一次运行时获取token
                self.update_token_if_needed()
            if param_key == 'force_refresh_token':
                self.token_obtained = False  # 强制刷新token时,将标志位设置为False
 No newline at end of file
+126 −27
Original line number Diff line number Diff line
@@ -11,7 +11,12 @@ from .base import register_textdetectors, TextDetectorBase, TextBlock
class StariverDetector(TextDetectorBase):

    params = {
        'token': "Replace with your token",
        'User': "填入你的用户名",
        'Password': "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用",
        'force_refresh_token': {
            'type': 'checkbox',
            'value': False
        },
        'expand_ratio': "0.01",
        "refine": {
            'type': 'checkbox',
@@ -42,8 +47,16 @@ class StariverDetector(TextDetectorBase):
    }

    @property
    def token(self):
        return self.params['token']
    def User(self):
        return self.params['User']

    @property
    def Password(self):
        return self.params['Password']

    @property
    def force_refresh_token(self):
        return self.params['force_refresh_token']['value']

    @property
    def expand_ratio(self):
@@ -93,7 +106,27 @@ class StariverDetector(TextDetectorBase):
        super().__init__(**params)
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'
        self.debug = False
        # self.name = 'StariverDetector'
        self.token = ''
        self.token_obtained = False  # 添加一个标志位来判断token是否已经获取过

        # 在初始化时尝试获取token
        if not self.token_obtained:
            self.update_token_if_needed()
            self.token_obtained = True  # 将标志位设置为True,表示已获取token

    def get_token(self):
        response = requests.post('https://capiv1.ap-sh.starivercs.cn/OCR/Admin/Login', json={
            "User": self.User,
            "Password": self.Password
        }).json()
        if response.get('Status', -1) != "Success":
            self.logger.error(
                f'stariver detector 登录失败,错误信息:{response.get("ErrorMsg", "")}')
        token = response.get('Token', '')
        if token != '':
            self.logger.info(f'stariver detector 登录成功,token前10位:{token[:10]}')

        return token

    def adjust_font_size(self, original_font_size):
        new_font_size = original_font_size + self.font_size_offset
@@ -104,10 +137,36 @@ class StariverDetector(TextDetectorBase):
        return new_font_size

    def detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
        if not self.token or self.token == 'Replace with your token':
            self.logger.error(f'token 没有设置。当前token:{self.token}')
            raise ValueError('token 没有设置。')
        img_encoded = cv2.imencode('.jpg', img)[1]
        if not self.token or self.token == '':
            self.logger.error(
                f'stariver detector token 没有设置。当前token:{self.token}')
            raise ValueError('stariver detector token 没有设置。')
        if self.low_accuracy_mode:
            self.logger.info('stariver detector 当前处于低精度模式。')
            short_side = 768
        else:
            short_side = 1536

        # 计算缩放比例
        height, width = img.shape[:2]
        scale = short_side / min(height, width)

        # 计算新的宽高
        new_width = int(width * scale)
        new_height = int(height * scale)

        # 按比例缩放图像
        if scale < 1:
            img_scaled = cv2.resize(
                img, (new_width, new_height), interpolation=cv2.INTER_AREA)
        else:
            img_scaled = img

        # 记录日志
        self.logger.debug(f'图像缩放比例:{scale},图像尺寸:{new_width}x{new_height}')

        # 编码图像为base64
        img_encoded = cv2.imencode('.jpg', img_scaled)[1]
        img_base64 = base64.b64encode(img_encoded).decode('utf-8')

        payload = {
@@ -124,19 +183,31 @@ class StariverDetector(TextDetectorBase):
        }
        if self.debug:
            payload_log = {k: v for k, v in payload.items() if k != 'image'}
            self.logger.debug(f'请求参数:{payload_log}')
            self.logger.debug(f'stariver detector 请求参数:{payload_log}')
        response = requests.post(self.url, json=payload)
        if response.status_code != 200:
            self.logger.error(f'请求失败,状态码:{response.status_code}')
            self.logger.error(
                f'stariver detector 请求失败,状态码:{response.status_code}')
            if response.json().get('Code', -1) != 0:
                self.logger.error(f'错误信息:{response.json().get("Message", "")}')
                self.logger.error(
                    f'stariver detector 错误信息:{response.json().get("Message", "")}')
                with open('stariver_ocr_error.txt', 'w', encoding='utf-8') as f:
                    f.write(response.text)
            raise ValueError('请求失败。')
            raise ValueError('stariver detector 请求失败。')
        response_data = response.json()['Data']

        blk_list = []
        for block in response_data.get('text_block', []):
            if scale < 1:
                xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values()) / scale),
                        int(min(
                            coord[1] for coord in block['block_coordinate'].values()) / scale),
                        int(max(
                            coord[0] for coord in block['block_coordinate'].values()) / scale),
                        int(max(coord[1] for coord in block['block_coordinate'].values()) / scale)]
                lines = [np.array([[coord[pos][0] / scale, coord[pos][1] / scale] for pos in ['upper_left', 'upper_right',
                                                                                              'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
            else:
                xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values())),
                        int(min(coord[1]
                            for coord in block['block_coordinate'].values())),
@@ -145,14 +216,19 @@ class StariverDetector(TextDetectorBase):
                        int(max(coord[1] for coord in block['block_coordinate'].values()))]
                lines = [np.array([[coord[pos][0], coord[pos][1]] for pos in ['upper_left', 'upper_right',
                                                                              'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
            texts = [text.replace('<skip>', '') for text in block.get('texts', [])]
            texts = [text.replace('<skip>', '')
                     for text in block.get('texts', [])]

            original_font_size = block.get('text_size', 0)

            font_size_recalculated = self.adjust_font_size(original_font_size)
            scaled_font_size = original_font_size / \
                scale if scale < 1 else original_font_size

            font_size_recalculated = self.adjust_font_size(scaled_font_size)

            if self.debug:
                self.logger.debug(f'原始字体大小:{original_font_size},修正后字体大小:{font_size_recalculated}')
                self.logger.debug(
                    f'原始字体大小:{original_font_size},修正后字体大小:{font_size_recalculated}')

            blk = TextBlock(
                xyxy=xyxy,
@@ -173,6 +249,12 @@ class StariverDetector(TextDetectorBase):

        mask = self._decode_base64_mask(response_data['mask'])
        mask = self.expand_mask(mask)

        # scale back to original size
        if scale < 1:
            mask = cv2.resize(mask, (width, height),
                              interpolation=cv2.INTER_NEAREST)
        self.logger.debug(f'检测结果mask尺寸:{mask.shape}')
        return mask, blk_list

    @staticmethod
@@ -214,3 +296,20 @@ class StariverDetector(TextDetectorBase):
        dilated_mask = (dilated_mask > 0).astype(np.uint8) * 255

        return dilated_mask

    def update_token_if_needed(self):
        if "填入你的用户名" not in self.User and "填入你的密码。请注意,密码会明文保存,请勿在公共电脑上使用" not in self.Password:
            if not self.token_obtained or self.force_refresh_token:  # 检查标志位,只有在第一次运行时获取token
                if len(self.Password) > 7 and len(self.User) >= 1:
                    self.token = self.get_token()
                    self.token_obtained = True  # 获取成功后,将标志位设置为True
        else:
            self.logger.warning('stariver detector 用户名或密码为空,无法更新token。')

    def updateParam(self, param_key: str, param_content):
        super().updateParam(param_key, param_content)
        if param_key == 'User' or param_key == 'Password':
            if not self.token_obtained or self.force_refresh_token:  # 检查标志位,只有在第一次运行时获取token
                self.update_token_if_needed()
        if param_key == 'force_refresh_token':
            self.token_obtained = False  # 强制刷新token时,将标志位设置为False