Commit 9e8c5cf9 authored by PiDanShouRouZhouXD's avatar PiDanShouRouZhouXD
Browse files

Add: Stariver(星河云/团子OCR) Detector and OCR

parent da6b8020
Loading
Loading
Loading
Loading
+31 −0
Original line number Diff line number Diff line
@@ -242,6 +242,37 @@ class OCRMIT48px(OCRBase):
        if self.device != device:
            self.model.to(device)

from .stariver_ocr import StariverOCR
@register_OCR('stariver_ocr')
class OCRStariver(OCRBase):
    """OCR module that delegates recognition to the Stariver web API.

    Recognition is performed by a ``StariverOCR`` client built from
    ``params['token']``; the client's token is refreshed whenever a parameter
    is updated through ``updateParam``.
    """

    # Default value of the token parameter; treated as "token not configured".
    _TOKEN_PLACEHOLDER = 'Replace with your token'

    params = {
        'token': _TOKEN_PLACEHOLDER,
        'description': '星河云(团子翻译器) OCR API'
    }

    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.client = StariverOCR(self.params['token'])

    def _check_token(self) -> None:
        """Raise ``ValueError`` if the token is empty or still the placeholder."""
        token = self.params['token']
        if not token or token == self._TOKEN_PLACEHOLDER:
            raise ValueError('token 没有设置。')

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Run OCR on the crop of every block and store the result in ``blk.text``.

        Blocks whose ``xyxy`` box is empty or not fully inside ``img`` get
        ``['']`` and a warning is logged instead of a request being sent.

        :param img: image the block coordinates refer to
        :param blk_list: blocks to fill in; modified in place
        :raises ValueError: if blocks are given but no token is configured
        """
        if not blk_list:
            return
        # Same guard as ocr_img: fail early with a clear message instead of
        # sending requests that cannot succeed.
        self._check_token()

        im_h, im_w = img.shape[:2]
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
            # Inclusive bounds: a box touching the image border is still a
            # valid slice (x1 == 0, x2 == im_w, ...).
            if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                blk.text = self.client.ocr(img[y1:y2, x1:x2])
            else:
                logging.warning('invalid textbbox to target img')
                blk.text = ['']

    def ocr_img(self, img: np.ndarray) -> List[str]:
        """Run OCR on the whole image and return the texts given by the client.

        :raises ValueError: if no token is configured
        """
        self._check_token()
        return self.client.ocr(img)

    def updateParam(self, param_key: str, param_content):
        """Forward the update to the base class, then sync the client's token."""
        super().updateParam(param_key, param_content)
        self.client.token = self.params['token']

    
import platform
+90 −0
Original line number Diff line number Diff line
import cv2
import requests
import json
import base64
import numpy as np
from utils.textblock import TextBlock

class StariverOCR:
    """Small HTTP client for the Stariver manga OCR endpoint.

    Every call base64-encodes the image, posts it as JSON together with the
    token and reads the ``Data`` member of the JSON response.
    """

    def __init__(self, token, timeout=60):
        """
        :param token: API token sent with every request. It is read on each
            call, so assigning ``client.token`` later takes effect immediately.
        :param timeout: timeout in seconds handed to ``requests.post`` so a
            stalled connection cannot block the caller forever.
        """
        self.token = token
        self.timeout = timeout
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'

    def _build_payload(self, img_base64, mask):
        """Return the request body for an image, with or without mask output."""
        return {
            "token": self.token,
            "mask": mask,
            "refine": True,
            "filtrate": True,
            "disable_skip_area": True,
            "detect_scale": 3,
            "merge_threshold": 0.5,
            "low_accuracy_mode": False,
            "image": img_base64
        }

    def _request(self, payload):
        """Post ``payload`` as JSON and return the ``Data`` dict of the reply.

        :raises requests.HTTPError: on a 4xx/5xx status
        :raises ValueError: if the reply carries no ``Data`` object
        """
        response = requests.post(self.url, json=payload, timeout=self.timeout)
        response.raise_for_status()
        body = response.json()
        data = body.get('Data') if isinstance(body, dict) else None
        if not isinstance(data, dict):
            # Show the start of the body so the server's own error is visible.
            raise ValueError(f'unexpected Stariver OCR response: {response.text[:200]}')
        return data

    @staticmethod
    def _bounding_xyxy(block_coordinate):
        """Return ``[x1, y1, x2, y2]`` enclosing all corner points of a block.

        :param block_coordinate: mapping of corner name to an ``(x, y)`` pair
        """
        corners = list(block_coordinate.values())
        xs = [pt[0] for pt in corners]
        ys = [pt[1] for pt in corners]
        return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]

    @staticmethod
    def _collect_texts(response_data):
        """Flatten the ``texts`` of every entry in ``text_block`` into one list."""
        text_blocks = response_data.get('text_block') or []
        return [text for block in text_blocks for text in (block.get('texts') or [])]

    def ocr(self, img: np.ndarray):
        """Recognise ``img`` and return the list of text strings found.

        The image is sent PNG-encoded; an empty list is returned when the
        reply contains no text blocks.
        """
        encoded = cv2.imencode('.png', img)[1]
        img_base64 = base64.b64encode(encoded).decode('utf-8')
        response_data = self._request(self._build_payload(img_base64, mask=False))
        return self._collect_texts(response_data)

    def ocr_verbose(self, img: np.ndarray):
        """
        For testing: return the mask and the list of TextBlock objects.

        The image is sent JPEG-encoded with mask output enabled.
        """
        img_encoded = cv2.imencode('.jpg', img)[1]
        img_base64 = base64.b64encode(img_encoded).decode('utf-8')
        response_data = self._request(self._build_payload(img_base64, mask=True))

        corner_order = ['upper_left', 'upper_right', 'lower_right', 'lower_left']
        blk_list = []
        for block in response_data.get('text_block', []):
            # One 4x2 float32 polygon per text line, corners in corner_order.
            lines = [
                np.array([[coord[pos][0], coord[pos][1]] for pos in corner_order], dtype=np.float32)
                for coord in block['coordinate']
            ]
            blk = TextBlock(
                xyxy=self._bounding_xyxy(block['block_coordinate']),
                lines=lines,
                language=block.get('language', 'unknown'),
                vertical=block.get('is_vertical', False),
                font_size=block.get('text_size', 0),
                distance=np.array([0, 0], dtype=np.float32),
                angle=0,
                vec=np.array([0, 0], dtype=np.float32),
                norm=0,
                merged=False,
                text=block.get('texts', ''),
                fg_colors=np.array(block.get('foreground_color', [0, 0, 0]), dtype=np.float32),
                bg_colors=np.array(block.get('background_color', [0, 0, 0]), dtype=np.float32)
            )
            blk_list.append(blk)

        mask = self._decode_base64_mask(response_data['mask'])
        return mask, blk_list

    @staticmethod
    def _decode_base64_mask(base64_str: str) -> np.ndarray:
        """Decode a base64 image string into a grayscale array.

        Returns ``None`` (after printing a message) when OpenCV cannot decode
        the data.
        """
        img_data = base64.b64decode(base64_str)
        img_array = np.frombuffer(img_data, dtype=np.uint8)
        mask = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            print("Error decoding the mask.")
        return mask
 No newline at end of file
+113 −1
Original line number Diff line number Diff line
import os
import base64
import requests
import numpy as np
import cv2
from typing import Union, List, Tuple
@@ -102,3 +105,112 @@ class ComicTextDetector(TextDetectorBase):
                else:
                    self.model.load_model(CTD_ONNX_PATH)
            self.model.detect_size = self.detect_size

@register_textdetectors('stariver_ocr')
class StariverDetector(TextDetectorBase):
    """Text detector backed by the Stariver manga OCR web API.

    ``detect`` uploads the page, converts every returned text block into a
    ``TextBlock`` and returns them together with the text mask, dilated by
    the configured ``expand_ratio``.
    """

    # Seconds to wait on the HTTP request before giving up.
    _REQUEST_TIMEOUT = 60

    params = {
        'token': "Replace with your token",
        'expand_ratio': "0.01",
        'description': '星河云(团子翻译器) OCR 文字检测器'
    }

    @property
    def token(self):
        """API token taken from the module parameters."""
        return self.params['token']

    @property
    def expand_ratio(self) -> float:
        """Mask expansion ratio; the parameter is stored as text, so parse it."""
        return float(self.params['expand_ratio'])

    def __init__(self, **params) -> None:
        # Let the base class process the supplied parameters.
        super().__init__(**params)
        self.url = 'https://dl.ap-sh.starivercs.cn/v2/manga_trans/advanced/manga_ocr'

    def _request_ocr(self, payload: dict) -> dict:
        """Post ``payload`` as JSON and return the ``Data`` dict of the reply.

        :raises requests.HTTPError: on a 4xx/5xx status
        :raises ValueError: if the reply carries no ``Data`` object
        """
        response = requests.post(self.url, json=payload, timeout=self._REQUEST_TIMEOUT)
        response.raise_for_status()
        body = response.json()
        data = body.get('Data') if isinstance(body, dict) else None
        if not isinstance(data, dict):
            # Show the start of the body so the server's own error is visible.
            raise ValueError(f'unexpected Stariver OCR response: {response.text[:200]}')
        return data

    @staticmethod
    def _parse_block(block: dict) -> TextBlock:
        """Build a ``TextBlock`` from one entry of the reply's ``text_block`` list."""
        corners = list(block['block_coordinate'].values())
        xs = [pt[0] for pt in corners]
        ys = [pt[1] for pt in corners]
        corner_order = ['upper_left', 'upper_right', 'lower_right', 'lower_left']
        # One 4x2 float32 polygon per text line, corners in corner_order.
        lines = [
            np.array([[coord[pos][0], coord[pos][1]] for pos in corner_order], dtype=np.float32)
            for coord in block['coordinate']
        ]
        return TextBlock(
            xyxy=[int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))],
            lines=lines,
            language=block.get('language', 'unknown'),
            vertical=block.get('is_vertical', False),
            font_size=block.get('text_size', 0),
            distance=np.array([0, 0], dtype=np.float32),
            angle=0,
            vec=np.array([0, 0], dtype=np.float32),
            norm=0,
            merged=False,
            text=block.get('texts', ''),
            fg_colors=np.array(block.get('foreground_color', [0, 0, 0]), dtype=np.float32),
            bg_colors=np.array(block.get('background_color', [0, 0, 0]), dtype=np.float32)
        )

    def detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
        """Detect text in ``img`` through the remote API.

        :param img: page image; it is JPEG-encoded before upload
        :return: ``(mask, blk_list)`` where ``mask`` is the expanded binary
            text mask and ``blk_list`` the detected blocks
        :raises ValueError: if the token is not configured, the reply has no
            ``Data`` object, or the returned mask cannot be decoded
        """
        if not self.token or self.token == 'Replace with your token':
            self.logger.error(f'token 没有设置。当前token:{self.token}')
            raise ValueError('token 没有设置。')
        img_encoded = cv2.imencode('.jpg', img)[1]
        img_base64 = base64.b64encode(img_encoded).decode('utf-8')

        payload = {
            "token": self.token,
            "mask": True,
            "refine": True,
            "filtrate": True,
            "disable_skip_area": True,
            "detect_scale": 3,
            "merge_threshold": 0.5,
            "low_accuracy_mode": False,
            "image": img_base64
        }
        response_data = self._request_ocr(payload)

        blk_list = [self._parse_block(block) for block in response_data.get('text_block', [])]

        mask = self._decode_base64_mask(response_data['mask'])
        if mask is None:
            # cv2.imdecode signals failure with None; stop here rather than
            # failing later inside expand_mask with an unrelated TypeError.
            raise ValueError('failed to decode the mask returned by Stariver OCR')
        # Use the user-configured ratio instead of the method default.
        mask = self.expand_mask(mask, self.expand_ratio)
        return mask, blk_list

    @staticmethod
    def _decode_base64_mask(base64_str: str) -> np.ndarray:
        """Decode a base64 image string into a grayscale array (``None`` on failure)."""
        img_data = base64.b64decode(base64_str)
        img_array = np.frombuffer(img_data, dtype=np.uint8)
        mask = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE)
        return mask

    def expand_mask(self, mask: np.ndarray, expand_ratio: float = 0.01) -> np.ndarray:
        """
        Dilate the mask around its original area so a larger text region is covered.
        :param mask: input mask
        :param expand_ratio: expansion ratio relative to the shorter image side, default 0.01
        :return: expanded mask containing only 0 and 255
        """
        # Make sure the mask is binary (only 0 and 255).
        mask = (mask > 0).astype(np.uint8) * 255

        height, width = mask.shape

        # Kernel size is a fraction of the shorter side; clamp to 1 so a zero
        # or negative ratio degrades to "no expansion" instead of an error.
        kernel_size = max(int(min(height, width) * expand_ratio), 1)
        if kernel_size % 2 == 0:
            kernel_size += 1  # keep the kernel size odd

        kernel = np.ones((kernel_size, kernel_size), np.uint8)

        # Dilating a 0/255 mask with a ones-kernel yields a 0/255 mask again.
        return cv2.dilate(mask, kernel, iterations=1)

    def updateParam(self, param_key: str, param_content):
        """Forward parameter updates to the base class unchanged."""
        super().updateParam(param_key, param_content)
 No newline at end of file