Commit f57e4013 authored by PiDanShouRouZhouXD's avatar PiDanShouRouZhouXD
Browse files

exposed more params for stariver det&ocr

parent f0eef570
Loading
Loading
Loading
Loading
+54 −1
Original line number Diff line number Diff line
@@ -247,14 +247,66 @@ from .stariver_ocr import StariverOCR
class OCRStariver(OCRBase):
    params = {
        'token': 'Replace with your token',
        "refine":{
            'type': 'selector',
            'options': [True, False],
            'select': True
        },
        "filtrate":{
            'type': 'selector',
            'options': [True, False],
            'select': True
        },
        "disable_skip_area":{
            'type': 'selector',
            'options': [True, False],
            'select': True
        },
        "detect_scale": "3",
        "merge_threshold": "2",
        'description': '星河云(团子翻译器) OCR API'
    }

    @property
    def token(self):
        return self.params['token']
    
    @property
    def expand_ratio(self):
        return float(self.params['expand_ratio'])
    
    @property
    def refine(self):
        if self.params['refine']['select'] == 'True':
            return True
        elif self.params['refine']['select'] == 'False':
            return False    
    @property
    def filtrate(self):
        if self.params['filtrate']['select'] == 'True':
            return True
        elif self.params['filtrate']['select'] == 'False':
            return False
    @property
    def disable_skip_area(self):
        if self.params['disable_skip_area']['select'] == 'True':
            return True
        elif self.params['disable_skip_area']['select'] == 'False':
            return False
    @property
    def detect_scale(self):
        return int(self.params['detect_scale'])
    
    @property
    def merge_threshold(self):
        return float(self.params['merge_threshold'])

    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.client = StariverOCR(self.params['token'])
        self.client = StariverOCR(self.token, refine=self.refine, filtrate=self.filtrate, disable_skip_area=self.disable_skip_area, detect_scale=self.detect_scale, merge_threshold=self.merge_threshold)

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        self.logger.debug(f'ocr_blk_list: {blk_list}')
        im_h, im_w = img.shape[:2]
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
@@ -266,6 +318,7 @@ class OCRStariver(OCRBase):
                blk.text = ['']

    def ocr_img(self, img: np.ndarray) -> str:
        self.logger.debug(f'ocr_img: {img.shape}')
        if not self.params['token'] or self.params['token'] == 'Replace with your token':
            raise ValueError('token 没有设置。')
        return self.client.ocr(img)
+79 −62
Original line number Diff line number Diff line
@@ -6,9 +6,17 @@ import numpy as np
from utils.textblock import TextBlock

class StariverOCR:
    def __init__(self, token):
    
    def __init__(self, token, detect_scale=3, merge_threshold=0.5, refine=True, filtrate=True, disable_skip_area=True):
        self.token = token
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'
        self.detect_scale = detect_scale
        self.merge_threshold = merge_threshold
        self.refine = refine
        self.filtrate = filtrate
        self.disable_skip_area = disable_skip_area
        self.low_accuracy_mode = False


    def ocr(self, img: np.ndarray):
        img = cv2.imencode('.png', img)[1]
@@ -16,75 +24,84 @@ class StariverOCR:
        data = {
            "token": self.token,
            "mask": False,
            "refine": True,
            "filtrate": True,
            "disable_skip_area": True,
            "detect_scale": 3,
            "merge_threshold": 0.5,
            "low_accuracy_mode": False,
            "refine": self.refine,
            "filtrate": self.filtrate,
            "disable_skip_area": self.disable_skip_area,
            "detect_scale": self.detect_scale,
            "merge_threshold": self.merge_threshold,
            "low_accuracy_mode": self.low_accuracy_mode,
            "image": img_base64
        }
        response = requests.post(self.url, data=json.dumps(data))
        if response.status_code!= 200:
            self.logger.error(f'请求失败,状态码:{response.status_code}')
            if response.json().get('Code', -1) != 0:
                self.logger.error(f'错误信息:{response.json().get("Message", "")}')
                with open('stariver_ocr_error.txt', 'w', encoding='utf-8') as f:
                    f.write(response.text)
            raise ValueError('请求失败。')

        text_blocks = response.json()['Data']['text_block']
        texts = [text for block in text_blocks for text in block['texts']]
        return texts
    
    def ocr_verbose(self, img: np.ndarray):
        """
        测试用,返回mask和TextBlock列表
        """
        img_encoded = cv2.imencode('.jpg', img)[1]
        img_base64 = base64.b64encode(img_encoded).decode('utf-8')
    # def ocr_verbose(self, img: np.ndarray):
    #     """
    #     测试用,返回mask和TextBlock列表
    #     """
    #     img_encoded = cv2.imencode('.jpg', img)[1]
    #     img_base64 = base64.b64encode(img_encoded).decode('utf-8')
        
        payload = {
            "token": self.token,
            "mask": True,
            "refine": True,
            "filtrate": True,
            "disable_skip_area": True,
            "detect_scale": 3,
            "merge_threshold": 0.5,
            "low_accuracy_mode": False,
            "image": img_base64
        }
        response = requests.post(self.url, json=payload)
        response_data = response.json()['Data']
    #     payload = {
    #         "token": self.token,
    #         "mask": True,
    #         "refine": self.refine,
    #         "filtrate": self.filtrate,
    #         "disable_skip_area": self.disable_skip_area,
    #         "detect_scale": self.detect_scale,
    #         "merge_threshold": self.merge_threshold,
    #         "low_accuracy_mode": self.low_accuracy_mode,
    #         "image": img_base64
    #     }

    #     response = requests.post(self.url, json=payload)
    #     response_data = response.json()['Data']

        blk_list = []
        for block in response_data.get('text_block', []):
            xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values())),
                    int(min(coord[1] for coord in block['block_coordinate'].values())),
                    int(max(coord[0] for coord in block['block_coordinate'].values())),
                    int(max(coord[1] for coord in block['block_coordinate'].values()))]
            lines = [np.array([[coord[pos][0], coord[pos][1]] for pos in ['upper_left', 'upper_right', 'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
            texts = block.get('texts', '')
            blk = TextBlock(
                xyxy=xyxy,
                lines=lines,
                language=block.get('language', 'unknown'),
                vertical=block.get('is_vertical', False),
                font_size=block.get('text_size', 0),
                distance=np.array([0, 0], dtype=np.float32),
                angle=0,
                vec=np.array([0, 0], dtype=np.float32),
                norm=0,
                merged=False,
                text=texts,
                fg_colors=np.array(block.get('foreground_color', [0, 0, 0]), dtype=np.float32),
                bg_colors=np.array(block.get('background_color', [0, 0, 0]), dtype=np.float32)
            )
            # print(blk.to_dict())
            blk_list.append(blk)
    #     blk_list = []
    #     for block in response_data.get('text_block', []):
    #         xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values())),
    #                 int(min(coord[1] for coord in block['block_coordinate'].values())),
    #                 int(max(coord[0] for coord in block['block_coordinate'].values())),
    #                 int(max(coord[1] for coord in block['block_coordinate'].values()))]
    #         lines = [np.array([[coord[pos][0], coord[pos][1]] for pos in ['upper_left', 'upper_right', 'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
    #         texts = block.get('texts', '')
    #         blk = TextBlock(
    #             xyxy=xyxy,
    #             lines=lines,
    #             language=block.get('language', 'unknown'),
    #             vertical=block.get('is_vertical', False),
    #             font_size=block.get('text_size', 0),
    #             distance=np.array([0, 0], dtype=np.float32),
    #             angle=0,
    #             vec=np.array([0, 0], dtype=np.float32),
    #             norm=0,
    #             merged=False,
    #             text=texts,
    #             fg_colors=np.array(block.get('foreground_color', [0, 0, 0]), dtype=np.float32),
    #             bg_colors=np.array(block.get('background_color', [0, 0, 0]), dtype=np.float32)
    #         )
    #         # print(blk.to_dict())
    #         blk_list.append(blk)
        
        mask = self._decode_base64_mask(response_data['mask'])
        return mask, blk_list
    #     mask = self._decode_base64_mask(response_data['mask'])
    #     return mask, blk_list

    @staticmethod
    def _decode_base64_mask(base64_str: str) -> np.ndarray:
        img_data = base64.b64decode(base64_str)
        img_array = np.frombuffer(img_data, dtype=np.uint8)
        mask = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print("Error decoding the mask.")
            return None
        return mask
 No newline at end of file
    # @staticmethod
    # def _decode_base64_mask(base64_str: str) -> np.ndarray:
    #     img_data = base64.b64decode(base64_str)
    #     img_array = np.frombuffer(img_data, dtype=np.uint8)
    #     mask = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE)
    #     if mask is None:
    #         print("Error decoding the mask.")
    #         return None
    #     return mask
 No newline at end of file
+103 −29
Original line number Diff line number Diff line
import os
import base64
from re import T
import requests
import numpy as np
import cv2
@@ -106,12 +106,35 @@ class ComicTextDetector(TextDetectorBase):
                    self.model.load_model(CTD_ONNX_PATH)
            self.model.detect_size = self.detect_size


@register_textdetectors('stariver_ocr')
class StariverDetector(TextDetectorBase):

    params = {
        'token': "Replace with your token",
        'expand_ratio': "0.01",
        "refine": {
            'type': 'selector',
            'options': [True, False],
            'select': True
        },
        "filtrate": {
            'type': 'selector',
            'options': [True, False],
            'select': True
        },
        "disable_skip_area": {
            'type': 'selector',
            'options': [True, False],
            'select': True
        },
        "detect_scale": "3",
        "merge_threshold": "2.0",
        "low_accuracy_mode": {
            'type': 'selector',
            'options': [True, False],
            'select': False
        },
        'description': '星河云(团子翻译器) OCR 文字检测器'
    }

@@ -121,11 +144,49 @@ class StariverDetector(TextDetectorBase):

    @property
    def expand_ratio(self):
        return self.params['expand_ratio'].eval()
        return float(self.params['expand_ratio'])

    @property
    def refine(self):
        if self.params['refine']['select'] == 'True':
            return True
        elif self.params['refine']['select'] == 'False':
            return False

    @property
    def filtrate(self):
        if self.params['filtrate']['select'] == 'True':
            return True
        elif self.params['filtrate']['select'] == 'False':
            return False

    @property
    def disable_skip_area(self):
        if self.params['disable_skip_area']['select'] == 'True':
            return True
        elif self.params['disable_skip_area']['select'] == 'False':
            return False

    @property
    def detect_scale(self):
        return int(self.params['detect_scale'])

    @property
    def merge_threshold(self):
        return float(self.params['merge_threshold'])

    @property
    def low_accuracy_mode(self):
        if self.params['low_accuracy_mode']['select'] == 'True':
            return True
        elif self.params['low_accuracy_mode']['select'] == 'False':
            return False

    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'
        self.name = 'StariverDetector'
        self.debug = True
        # self.name = 'StariverDetector'

    def detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
        if not self.token or self.token == 'Replace with your token':
@@ -137,24 +198,37 @@ class StariverDetector(TextDetectorBase):
        payload = {
            "token": self.token,
            "mask": True,
            "refine": True,
            "filtrate": True,
            "disable_skip_area": True,
            "detect_scale": 3,
            "merge_threshold": 0.5,
            "low_accuracy_mode": False,
            "refine": self.refine,
            "filtrate": self.filtrate,
            "disable_skip_area": self.disable_skip_area,
            "detect_scale": self.detect_scale,
            "merge_threshold": self.merge_threshold,
            "low_accuracy_mode": self.low_accuracy_mode,
            "image": img_base64
        }
        if self.debug:
            payload_log = {k: v for k, v in payload.items() if k != 'image'}
            self.logger.debug(f'请求参数:{payload_log}')
        response = requests.post(self.url, json=payload)
        if response.status_code != 200:
            self.logger.error(f'请求失败,状态码:{response.status_code}')
            if response.json().get('Code', -1) != 0:
                self.logger.error(f'错误信息:{response.json().get("Message", "")}')
                with open('stariver_ocr_error.txt', 'w', encoding='utf-8') as f:
                    f.write(response.text)
            raise ValueError('请求失败。')
        response_data = response.json()['Data']

        blk_list = []
        for block in response_data.get('text_block', []):
            xyxy = [int(min(coord[0] for coord in block['block_coordinate'].values())),
                    int(min(coord[1] for coord in block['block_coordinate'].values())),
                    int(max(coord[0] for coord in block['block_coordinate'].values())),
                    int(min(coord[1]
                        for coord in block['block_coordinate'].values())),
                    int(max(coord[0]
                        for coord in block['block_coordinate'].values())),
                    int(max(coord[1] for coord in block['block_coordinate'].values()))]
            lines = [np.array([[coord[pos][0], coord[pos][1]] for pos in ['upper_left', 'upper_right', 'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
            lines = [np.array([[coord[pos][0], coord[pos][1]] for pos in ['upper_left', 'upper_right',
                              'lower_right', 'lower_left']], dtype=np.float32) for coord in block['coordinate']]
            texts = block.get('texts', '')
            blk = TextBlock(
                xyxy=xyxy,
@@ -162,16 +236,16 @@ class StariverDetector(TextDetectorBase):
                language=block.get('language', 'unknown'),
                vertical=block.get('is_vertical', False),
                font_size=block.get('text_size', 0),
                distance=np.array([0, 0], dtype=np.float32),
                angle=0,
                vec=np.array([0, 0], dtype=np.float32),
                norm=0,
                merged=False,
                
                text=texts,
                fg_colors=np.array(block.get('foreground_color', [0, 0, 0]), dtype=np.float32),
                bg_colors=np.array(block.get('background_color', [0, 0, 0]), dtype=np.float32)
                fg_colors=np.array(block.get('foreground_color', [
                                   0, 0, 0]), dtype=np.float32),
                bg_colors=np.array(block.get('background_color', [
                                   0, 0, 0]), dtype=np.float32)
            )
            blk_list.append(blk)
            if self.debug:
                self.logger.debug(f'检测到文本块:{blk.to_dict()}')

        mask = self._decode_base64_mask(response_data['mask'])
        mask = self.expand_mask(mask)