Commit 3be64ed4 authored by PiDanShouRouZhouXD's avatar PiDanShouRouZhouXD
Browse files

fix ocr, add none ocr

parent f57e4013
Loading
Loading
Loading
Loading
+31 −5
Original line number Diff line number Diff line
@@ -37,7 +37,9 @@ class OCRBase(BaseModule):
            blk_list = [blk_list]

        for blk in blk_list:
            if self.name != 'none_ocr':
                blk.text = []
                
        self._ocr_blk_list(img, blk_list)
        for callback_name, callback in self._postprocess_hooks.items():
            callback(textblocks=blk_list, img=img, ocr_module=self)
@@ -264,6 +266,11 @@ class OCRStariver(OCRBase):
        },
        "detect_scale": "3",
        "merge_threshold": "2",
        "force_expand":{
            'type': 'selector',
            'options': [True, False],
            'select': False
        },
        'description': '星河云(团子翻译器) OCR API'
    }

@@ -301,12 +308,18 @@ class OCRStariver(OCRBase):
    def merge_threshold(self):
        """Return the configured ``merge_threshold`` parameter as a float."""
        raw_threshold = self.params['merge_threshold']
        return float(raw_threshold)
    
    @property
    def force_expand(self):
        """Return the ``force_expand`` selector as a real ``bool``.

        The selector is declared with boolean options and a boolean default
        (``'select': False``), while the previous implementation only
        recognised the strings ``'True'`` and ``'False'`` and fell through to
        ``None`` for anything else, including the default. Both spellings
        are accepted here, so the value passed to the OCR client is always a
        ``bool``.

        Returns:
            bool: ``True`` when the selector is the boolean ``True`` or the
            string ``'True'`` (case-insensitive); ``False`` otherwise.
        """
        selected = self.params['force_expand']['select']
        if isinstance(selected, str):
            # String form of the selection; only "true" enables the option.
            return selected.strip().lower() == 'true'
        return bool(selected)

    def __init__(self, **params) -> None:
        """Initialise the module and build the Stariver OCR client.

        The client is constructed exactly once from the token and the option
        properties of this module, including ``force_expand``. Building it
        twice, with the first instance discarded immediately, was redundant.
        """
        super().__init__(**params)
        self.client = StariverOCR(
            self.token,
            refine=self.refine,
            filtrate=self.filtrate,
            disable_skip_area=self.disable_skip_area,
            detect_scale=self.detect_scale,
            merge_threshold=self.merge_threshold,
            force_expand=self.force_expand,
        )

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        self.logger.debug(f'ocr_blk_list: {blk_list}')
        im_h, im_w = img.shape[:2]
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
@@ -319,14 +332,27 @@ class OCRStariver(OCRBase):

    def ocr_img(self, img: np.ndarray) -> str:
        """Run the Stariver client on ``img`` and return its result.

        Raises:
            ValueError: if the ``token`` parameter is empty or still holds
                the ``'Replace with your token'`` placeholder.
        """
        self.logger.debug(f'ocr_img: {img.shape}')
        token = self.params['token']
        token_missing = not token or token == 'Replace with your token'
        if token_missing:
            raise ValueError('token 没有设置。')
        recognised = self.client.ocr(img)
        return recognised

    def updateParam(self, param_key: str, param_content):
        """Apply a parameter change and keep the OCR client's token in sync.

        Args:
            param_key: Name of the parameter being updated.
            param_content: New value for that parameter.
        """
        super().updateParam(param_key, param_content)
        token = self.params['token']
        self.client.token = token
        # StariverOCR assembles its request payload once in its constructor
        # and ``ocr`` reads the token from that payload, so the stored copy
        # has to be refreshed as well or a changed token is never sent.
        payload = getattr(self.client, 'params', None)
        if isinstance(payload, dict):
            payload['token'] = token

@register_OCR('none_ocr')
class OCRNone(OCRBase):
    """Placeholder OCR module registered as ``'none_ocr'``.

    It performs no recognition: block lists are left untouched and
    ``ocr_img`` always yields an empty string.
    """

    params = {
        'NOTICE': 'Not a OCR, just return original text.',
        'description': 'Not a OCR, just return original text.'
    }

    def __init__(self, **params) -> None:
        """Forward all parameters to the base class unchanged."""
        super().__init__(**params)

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Do nothing, leaving every block in ``blk_list`` as it is."""
        return None

    def ocr_img(self, img: np.ndarray) -> str:
        """Return an empty string regardless of ``img``."""
        return str()
    
import platform
if platform.mac_ver()[0] >= '10.15':
+37 −88
Original line number Diff line number Diff line
@@ -3,105 +3,54 @@ import requests
import json
import base64
import numpy as np
from utils.textblock import TextBlock


class StariverOCR:

    def __init__(self, token, detect_scale=3, merge_threshold=0.5, refine=True, filtrate=True, disable_skip_area=True):
    def __init__(self, token, detect_scale=3, merge_threshold=0.5, refine=True, filtrate=True, disable_skip_area=True, force_expand=False):
        self.token = token
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'
        self.detect_scale = detect_scale
        self.merge_threshold = merge_threshold
        self.refine = refine
        self.filtrate = filtrate
        self.disable_skip_area = disable_skip_area
        self.low_accuracy_mode = False


    def ocr(self, img: np.ndarray):
        img = cv2.imencode('.png', img)[1]
        img_base64 = base64.b64encode(img).decode('utf-8')
        data = {
        self.debug = False
        self.params = {
            "token": self.token,
            "mask": False,
            "refine": self.refine,
            "filtrate": self.filtrate,
            "disable_skip_area": self.disable_skip_area,
            "detect_scale": self.detect_scale,
            "merge_threshold": self.merge_threshold,
            "low_accuracy_mode": self.low_accuracy_mode,
            "image": img_base64
            "refine": refine,
            "filtrate": filtrate,
            "disable_skip_area": disable_skip_area,
            "detect_scale": detect_scale,
            "merge_threshold": merge_threshold,
            "low_accuracy_mode": True,
            "force_expand": force_expand
        }
        response = requests.post(self.url, data=json.dumps(data))

    def ocr(self, img: np.ndarray) -> str:
        if not self.params['token'] or self.params['token'] == 'Replace with your token':
            raise ValueError('token 没有设置。')

        img_base64 = base64.b64encode(
            cv2.imencode('.jpg', img)[1]).decode('utf-8')
        self.params["image"] = img_base64

        response = requests.post(self.url, data=json.dumps(self.params))

        if response.status_code != 200:
            self.logger.error(f'请求失败,状态码:{response.status_code}')
            print(f'请求失败,状态码:{response.status_code}')
            if response.json().get('Code', -1) != 0:
                self.logger.error(f'错误信息:{response.json().get("Message", "")}')
                print(f'错误信息:{response.json().get("Message", "")}')
                with open('stariver_ocr_error.txt', 'w', encoding='utf-8') as f:
                    f.write(response.text)
            raise ValueError('请求失败。')

        text_blocks = response.json()['Data']['text_block']
        texts = [text for block in text_blocks for text in block['texts']]
        return texts
    
    # def ocr_verbose(self, img: np.ndarray):
    #     """
    #     测试用,返回mask和TextBlock列表
    #     """
    #     img_encoded = cv2.imencode('.jpg', img)[1]
    #     img_base64 = base64.b64encode(img_encoded).decode('utf-8')
        
    #     payload = {
    #         "token": self.token,
    #         "mask": True,
    #         "refine": self.refine,
    #         "filtrate": self.filtrate,
    #         "disable_skip_area": self.disable_skip_area,
    #         "detect_scale": self.detect_scale,
    #         "merge_threshold": self.merge_threshold,
    #         "low_accuracy_mode": self.low_accuracy_mode,
    #         "image": img_base64
    #     }

    #     response = requests.post(self.url, json=payload)
    #     response_data = response.json()['Data']

    #     blk_list = []
    #     for block in response_data.get('text_block', []):
    #         xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values())),
    #                 int(min(coord[1] for coord in block['block_coordinate'].values())),
    #                 int(max(coord[0] for coord in block['block_coordinate'].values())),
    #                 int(max(coord[1] for coord in block['block_coordinate'].values()))]
    #         lines = [np.array([[coord[pos][0], coord[pos][1]] for pos in ['upper_left', 'upper_right', 'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
    #         texts = block.get('texts', '')
    #         blk = TextBlock(
    #             xyxy=xyxy,
    #             lines=lines,
    #             language=block.get('language', 'unknown'),
    #             vertical=block.get('is_vertical', False),
    #             font_size=block.get('text_size', 0),
    #             distance=np.array([0, 0], dtype=np.float32),
    #             angle=0,
    #             vec=np.array([0, 0], dtype=np.float32),
    #             norm=0,
    #             merged=False,
    #             text=texts,
    #             fg_colors=np.array(block.get('foreground_color', [0, 0, 0]), dtype=np.float32),
    #             bg_colors=np.array(block.get('background_color', [0, 0, 0]), dtype=np.float32)
    #         )
    #         # print(blk.to_dict())
    #         blk_list.append(blk)
        response_data = response.json()['Data']

    #     mask = self._decode_base64_mask(response_data['mask'])
    #     return mask, blk_list
        if self.debug:
            id = response.json().get('RequestID', '')
            file_name = f"stariver_ocr_response_{id}.json"
            print(f"请求成功,响应数据已保存至{file_name}")
            with open(file_name, 'w', encoding='utf-8') as f:
                json.dump(response_data, f, ensure_ascii=False, indent=4)

    # @staticmethod
    # def _decode_base64_mask(base64_str: str) -> np.ndarray:
    #     img_data = base64.b64decode(base64_str)
    #     img_array = np.frombuffer(img_data, dtype=np.uint8)
    #     mask = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE)
    #     if mask is None:
    #         print("Error decoding the mask.")
    #         return None
    #     return mask
 No newline at end of file
        texts_list = ["".join(block.get('texts', '')).strip()
                      for block in response_data.get('text_block', [])]
        texts_str = "".join(texts_list)
        return texts_str
+19 −2
Original line number Diff line number Diff line
@@ -135,6 +135,11 @@ class StariverDetector(TextDetectorBase):
            'options': [True, False],
            'select': False
        },
        "force_expand":{
            'type': 'selector',
            'options': [True, False],
            'select': False
        },
        'description': '星河云(团子翻译器) OCR 文字检测器'
    }

@@ -182,10 +187,17 @@ class StariverDetector(TextDetectorBase):
        elif self.params['low_accuracy_mode']['select'] == 'False':
            return False
        
    @property
    def force_expand(self):
        """Return the ``force_expand`` selector as a real ``bool``.

        The selector is declared with boolean options and a boolean default
        (``'select': False``), while the previous implementation only
        recognised the strings ``'True'`` and ``'False'`` and fell through to
        ``None`` for anything else, including the default. Both spellings
        are accepted here, so the request payload always carries a ``bool``.

        Returns:
            bool: ``True`` when the selector is the boolean ``True`` or the
            string ``'True'`` (case-insensitive); ``False`` otherwise.
        """
        selected = self.params['force_expand']['select']
        if isinstance(selected, str):
            # String form of the selection; only "true" enables the option.
            return selected.strip().lower() == 'true'
        return bool(selected)

    def __init__(self, **params) -> None:
        """Initialise the detector with its endpoint URL and debug flag.

        ``debug`` is assigned once and starts disabled; the earlier
        ``self.debug = True`` was overwritten on the very next line and had
        no effect.
        """
        super().__init__(**params)
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'
        self.debug = False
        # self.name = 'StariverDetector'

    def detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
@@ -204,6 +216,7 @@ class StariverDetector(TextDetectorBase):
            "detect_scale": self.detect_scale,
            "merge_threshold": self.merge_threshold,
            "low_accuracy_mode": self.low_accuracy_mode,
            "force_expand": self.force_expand,
            "image": img_base64
        }
        if self.debug:
@@ -265,6 +278,10 @@ class StariverDetector(TextDetectorBase):
        :param expand_ratio: 扩展比例,默认值为0.01
        :return: 扩展后的mask
        """

        if expand_ratio == 0:
            return mask
        
        # 确保mask是二值图像(只含0和255)
        mask = (mask > 0).astype(np.uint8) * 255