Commit e4667a0b authored by dmMaze's avatar dmMaze
Browse files

Merge branch 'dev' of https://github.com/dmMaze/BallonsTranslator into dev

parents c9b21e36 db58002e
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -113,6 +113,7 @@ class BaseModule:
        model_deleted = False
        if self._load_model_keys is not None:
            for k in self._load_model_keys:
                if hasattr(self, k):
                    model = getattr(self, k)
                    if model is not None:
                        del model
+281 −0
Original line number Diff line number Diff line
import numpy as np
import time
import cv2
import re
from typing import List

import requests
from PIL import Image
import io
import time
# import json5
import json5
import lxml.html
import http.cookiejar as cookielib
import logging

from .base import register_OCR, OCRBase, TextBlock

class LensCore:
    """Minimal HTTP client for the Google Lens image-upload endpoint.

    Each instance owns a ``requests`` session backed by its own cookie jar and
    a logger named ``'LensCore'``. :meth:`scan_by_data` uploads an already
    encoded image and returns the data parsed out of the result page.
    """

    LENS_ENDPOINT = 'https://lens.google.com/v3/upload'
    # Seconds to wait for the upload before giving up; without a timeout
    # ``requests`` may block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 30
    SUPPORTED_MIMES = [
        'image/x-icon', 'image/bmp', 'image/jpeg',
        'image/png', 'image/tiff', 'image/webp', 'image/heic'
    ]
    # NOTE(review): 'br' is advertised explicitly here; decoding brotli
    # responses presumably needs a brotli package installed - confirm.
    HEADERS = {
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
        'Accept-Encoding': 'gzip, deflate, br',
        'Accept-Language': 'en-US,en;q=0.9',
        'Origin': 'https://lens.google.com',
        'Referer': 'https://lens.google.com/',
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    }

    def __init__(self):
        self.cookie_jar = cookielib.CookieJar()
        self.session = requests.Session()
        self.session.cookies = self.cookie_jar
        self.logger = self._get_logger('LensCore')

    @staticmethod
    def _get_logger(name):
        """Return the logger called *name*, set to INFO with one stream handler.

        ``logging.getLogger`` hands back the same object for the same name, so
        the handler is attached only when the logger has none yet; otherwise
        every new instance would add another handler and duplicate each line.
        """
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setLevel(logging.INFO)
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            )
            logger.addHandler(handler)
        return logger

    def scan_by_data(self, data, mime, dimensions):
        """Upload encoded image bytes and return the parsed result data.

        Args:
            data: Encoded image bytes sent as the ``encoded_image`` form part.
            mime: MIME type reported for the uploaded image.
            dimensions: Sequence whose first two items are width and height.

        Returns:
            The object produced by ``json5.loads`` from the payload of the
            ``<script class="ds:1">`` element of the response page.

        Raises:
            Exception: If the response status is not 200 or the expected
                script element is missing from the response.
        """
        headers = self.HEADERS.copy()
        width, height = dimensions[0], dimensions[1]
        files = {
            'encoded_image': ('image.jpg', data, mime),
            'original_width': (None, str(width)),
            'original_height': (None, str(height)),
            'processed_image_dimensions': (None, f"{width},{height}")
        }
        response = self.session.post(
            self.LENS_ENDPOINT,
            headers=headers,
            files=files,
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            raise Exception(f"Failed to upload image. Status code: {response.status_code}")

        tree = lxml.html.parse(io.StringIO(response.text))
        scripts = tree.xpath("//script[@class='ds:1']")
        if not scripts or not scripts[0].text:
            # Fail with a readable message instead of a bare IndexError/TypeError.
            raise Exception("Failed to parse Lens response: result script 'ds:1' not found")
        # Drop the "AF_initDataCallback(" prefix and the final two characters
        # (presumably the closing ");") so only the object literal is parsed.
        return json5.loads(scripts[0].text[len("AF_initDataCallback("):-2])


class Lens(LensCore):
    """Lens client that accepts image files or raw bytes.

    Images are shrunk, re-encoded as JPEG and handed to ``scan_by_data``.
    """

    @staticmethod
    def resize_image(image, max_size=(1000, 1000)):
        """Shrink *image* to fit *max_size* and encode it as JPEG.

        Args:
            image: PIL image; ``thumbnail`` resizes it in place.
            max_size: Maximum ``(width, height)`` passed to ``thumbnail``.

        Returns:
            A ``(jpeg_bytes, size)`` tuple, where ``size`` is the size of the
            image that was encoded.
        """
        image.thumbnail(max_size)
        # The JPEG writer rejects palette, alpha and several other modes, so
        # anything that is not plain L/RGB/CMYK is converted to RGB first
        # (previously only RGBA was converted and e.g. 'P' or 'LA' failed).
        if image.mode not in ('L', 'RGB', 'CMYK'):
            image = image.convert('RGB')
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        return buffer.getvalue(), image.size

    def scan_by_file(self, file_path):
        """Open the image at *file_path*, re-encode it and upload it."""
        with Image.open(file_path) as img:
            img_data, dimensions = self.resize_image(img)
        return self.scan_by_data(img_data, 'image/jpeg', dimensions)

    def scan_by_buffer(self, buffer):
        """Decode encoded image bytes in *buffer*, re-encode and upload them."""
        # Use a context manager so the decoded image is closed, as in
        # scan_by_file.
        with Image.open(io.BytesIO(buffer)) as img:
            img_data, dimensions = self.resize_image(img)
        return self.scan_by_data(img_data, 'image/jpeg', dimensions)


class LensAPI:
    """Wrapper around :class:`Lens` that extracts text and language.

    The extractor helpers never raise on malformed data; they return fixed
    fallback strings, which callers compare against to detect a miss.
    """

    # Fallback strings returned when the expected structure is missing.
    FULL_TEXT_NOT_FOUND = "Full text not found(or Lens could not recognize it)"
    LANGUAGE_NOT_FOUND = "Language not found in expected structure"

    def __init__(self):
        self.lens = Lens()
        self.logger = logging.getLogger('LensAPI')
        self.logger.setLevel(logging.INFO)
        # The named logger is shared between instances; attach the stream
        # handler only once so repeated construction does not duplicate output.
        if not self.logger.handlers:
            ch = logging.StreamHandler()
            ch.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            ch.setFormatter(formatter)
            self.logger.addHandler(ch)

    @staticmethod
    def extract_full_text(data):
        """Return the text found at ``data[3][4][0][0]``.

        A list at that position is joined with newlines. When the path does
        not exist, or the value is not text, the fallback string
        ``"Full text not found(or Lens could not recognize it)"`` is returned.
        """
        try:
            text_data = data[3][4][0][0]
            if isinstance(text_data, list):
                return "\n".join(text_data)
        except (IndexError, KeyError, TypeError):
            # KeyError covers dict-shaped data; TypeError also covers a list
            # holding non-string items passed to join.
            return LensAPI.FULL_TEXT_NOT_FOUND
        if not isinstance(text_data, str):
            # Callers apply string operations to the result, so never hand
            # back None, numbers or nested containers.
            return LensAPI.FULL_TEXT_NOT_FOUND
        return text_data

    @staticmethod
    def extract_language(data):
        """Return ``data[3][3]``, or a fallback string when it is missing."""
        try:
            return data[3][3]
        except (IndexError, KeyError, TypeError):
            return LensAPI.LANGUAGE_NOT_FOUND

    def process_image(self, image_path=None, image_buffer=None):
        """Scan an image and return its recognized text and language.

        Args:
            image_path: Path of an image file; takes precedence when truthy.
            image_buffer: Encoded image bytes, used when no path is given.

        Returns:
            A dict with the keys ``'full_text'`` and ``'language'``.

        Raises:
            ValueError: If neither *image_path* nor *image_buffer* is given.
        """
        if image_path:
            result = self.lens.scan_by_file(image_path)
        elif image_buffer:
            result = self.lens.scan_by_buffer(image_buffer)
        else:
            raise ValueError("Either image_path or image_buffer must be provided")

        payload = result['data']
        return {
            'full_text': self.extract_full_text(payload),
            'language': self.extract_language(payload)
        }


@register_OCR('lens_api_ocr')
class OCRLensAPI(OCRBase):
    """OCR module that recognizes text by uploading images through LensAPI.

    Requests are spaced at least ``delay`` seconds apart, and the returned
    text is post-processed according to the ``newline_handling`` and
    ``no_uppercase`` parameters. ``self.logger`` is not defined in this class
    and is presumably provided by the base class.
    """

    params = {
        "delay": 1.0,
        'debug': {
            'type': 'checkbox',
            'value': False,
            'description': 'Enable debug logging'
        },
        'newline_handling': {
            'type': 'selector',
            'options': [
                'preserve',
                'remove'
            ],
            'value': 'preserve',
            'description': 'Choose how to handle newline characters in OCR result'
        },
        'no_uppercase': {
            'type': 'checkbox',
            'value': False,
            'description': 'Convert text to lowercase except the first letter of each sentence'
        },
        'description': 'OCR using Google Lens API'
    }

    # Strings that stand for "nothing recognized" in the LensAPI result.
    _IGNORED_RESULTS = (
        'Full text not found in expected structure',
        'Full text not found(or Lens could not recognize it)',
    )

    def _param_value(self, key):
        """Return the value of parameter *key*.

        Parameters are declared either as a bare value (``"delay": 1.0``) or
        as a dict carrying the value under ``'value'``; both are supported.
        """
        entry = self.params[key]
        if isinstance(entry, dict):
            return entry['value']
        return entry

    @property
    def request_delay(self):
        """Minimum number of seconds between two requests."""
        try:
            return float(self._param_value('delay'))
        except (TypeError, ValueError):
            # An unparsable value (e.g. a half-typed number) falls back to
            # the declared default instead of breaking every OCR call.
            return 1.0

    @property
    def debug_mode(self):
        """Whether verbose logging is enabled."""
        return bool(self._param_value('debug'))

    @property
    def newline_handling(self):
        """Selected newline mode: ``'preserve'`` or ``'remove'``."""
        return self._param_value('newline_handling')

    @property
    def no_uppercase(self):
        """Whether text is lowercased except for sentence starts."""
        return bool(self._param_value('no_uppercase'))

    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.api = LensAPI()
        self.last_request_time = 0

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock], *args, **kwargs):
        """Run OCR on the crop of every block and store the result in ``blk.text``.

        Blocks whose ``xyxy`` box is empty or reaches outside the image get
        ``['']`` instead.
        """
        im_h, im_w = img.shape[:2]
        if self.debug_mode:
            self.logger.info(f'Image size: {im_h}x{im_w}')
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
            if self.debug_mode:
                self.logger.info(f'Processing block: ({x1, y1, x2, y2})')
            # Inclusive bounds: a box that touches the image border is still
            # a valid slice and must not be rejected.
            if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                cropped_img = img[y1:y2, x1:x2]
                if self.debug_mode:
                    self.logger.info(f'Cropped image size: {cropped_img.shape}')
                blk.text = self.ocr(cropped_img)
            else:
                if self.debug_mode:
                    self.logger.warning('Invalid text bbox to target image')
                blk.text = ['']

    def ocr_img(self, img: np.ndarray) -> str:
        """Run OCR on a whole image and return the recognized text."""
        if self.debug_mode:
            self.logger.debug(f'ocr_img: {img.shape}')
        return self.ocr(img)

    def ocr(self, img: np.ndarray) -> str:
        """Recognize the text in *img*.

        Returns:
            The post-processed text, or ``''`` when the image is empty,
            nothing was recognized, or the request failed.
        """
        if self.debug_mode:
            self.logger.info(f'Starting OCR on image of shape: {img.shape}')
        # Reject empty input before waiting: no request is sent for it, so
        # there is nothing to rate-limit.
        if getattr(img, 'size', 0) == 0:
            if self.debug_mode:
                self.logger.warning('Empty image provided for OCR')
            return ''
        self._respect_delay()
        try:
            if self.debug_mode:
                self.logger.info(f'Input image size: {img.shape}')
            encoded_ok, buffer = cv2.imencode('.jpg', img)
            if not encoded_ok:
                raise ValueError('Failed to encode image as JPEG')
            result = self.api.process_image(image_buffer=buffer.tobytes())
            if self.debug_mode:
                self.logger.info(f'OCR result: {result}')

            full_text = result['full_text']
            if isinstance(full_text, list):
                full_text = '\n'.join(full_text)
            if not isinstance(full_text, str) or full_text in self._IGNORED_RESULTS:
                return ''

            if self.newline_handling == 'remove':
                full_text = full_text.replace('\n', ' ')

            # Apply punctuation and spacing rules
            full_text = self._apply_punctuation_and_spacing(full_text)

            if self.no_uppercase:
                full_text = self._apply_no_uppercase(full_text)
            return full_text
        except Exception as e:
            # Best effort: a failed request yields an empty result, but the
            # cause is always logged so failures are not invisible.
            self.logger.error(f"OCR error: {str(e)}")
            return ''

    def _apply_no_uppercase(self, text: str) -> str:
        """Lowercase *text* except the first letter of each sentence.

        Sentences end at ``.``, ``!``, ``?`` or ``…``; whitespace is kept
        exactly as it appears in the input.
        """
        def process_sentence(sentence):
            # Alternating runs of non-whitespace and whitespace, so joining
            # the tokens reproduces the original spacing.
            tokens = re.findall(r'\S+|\s+', sentence)
            processed = []
            first_word_done = False
            for token in tokens:
                if token.isspace():
                    processed.append(token)
                elif not first_word_done:
                    # Only the first word of the sentence is capitalized.
                    processed.append(token.capitalize())
                    first_word_done = True
                else:
                    processed.append(token.lower())
            return ''.join(processed)

        # Split after each sentence-ending mark; the mark stays attached to
        # the sentence it closes.
        sentences = re.split(r'(?<=[.!?…])', text)
        return ''.join(process_sentence(sentence) for sentence in sentences)

    def _apply_punctuation_and_spacing(self, text: str) -> str:
        """Normalize spacing around ``, . ! ? …`` and strip the result."""
        # Remove extra spaces before punctuation
        text = re.sub(r'\s+([,.!?…])', r'\1', text)

        # Ensure single space after punctuation, except for multiple punctuation marks
        text = re.sub(r'([,.!?…])(?!\s)(?![,.!?…])', r'\1 ', text)

        # Remove space between multiple punctuation marks
        text = re.sub(r'([,.!?…])\s+([,.!?…])', r'\1\2', text)

        return text.strip()

    def _respect_delay(self):
        """Sleep as needed so requests are at least ``request_delay`` apart."""
        current_time = time.time()
        time_since_last_request = current_time - self.last_request_time
        if self.debug_mode:
            self.logger.info(f'Time since last request: {time_since_last_request} seconds')
        if time_since_last_request < self.request_delay:
            sleep_time = self.request_delay - time_since_last_request
            if self.debug_mode:
                self.logger.info(f'Sleeping for {sleep_time} seconds')
            time.sleep(sleep_time)
        self.last_request_time = time.time()

    def updateParam(self, param_key: str, param_content):
        """Forward the parameter update to the base class."""
        super().updateParam(param_key, param_content)
+78 −0
Original line number Diff line number Diff line
import requests
import hashlib
import time
import uuid
from typing import List, Dict
from .base import *

@register_translator('Youdao')
class YoudaoTranslator(BaseTranslator):
    """Translator that calls the Youdao OpenAPI text endpoint.

    Every source string is sent as its own signed request, using the
    ``api_key`` and ``app_secret`` parameters.
    """

    concate_text = False
    cht_require_convert = True
    params: Dict = {
        'api_key': '',
        'app_secret': '',
    }

    # Seconds to wait for one API call; without a timeout ``requests`` may
    # block indefinitely on a stalled connection.
    _REQUEST_TIMEOUT = 30

    @property
    def api_key(self) -> str:
        """Application key sent as ``appKey`` and used in the signature."""
        return self.params['api_key']

    @property
    def app_secret(self) -> str:
        """Application secret used only to compute the signature."""
        return self.params['app_secret']

    def _setup_translator(self):
        """Register the language codes used for the supported languages."""
        self.lang_map['简体中文'] = 'zh-CHS'
        self.lang_map['English'] = 'en'
        self.lang_map['日本語'] = 'ja'
        self.lang_map['한국어'] = 'ko'
        # Add more language mappings as needed

    def generate_input(self, query: str) -> str:
        """Return the ``input`` part of the signature for *query*.

        Queries longer than 20 characters are shortened to their first ten
        characters, their length, and their last ten characters.
        """
        if len(query) > 20:
            return query[:10] + str(len(query)) + query[-10:]
        return query

    def generate_sign(self, query: str, salt: str, curtime: str) -> str:
        """Return the SHA-256 hex digest that signs one request."""
        sign_str = self.api_key + self.generate_input(query) + salt + curtime + self.app_secret
        return hashlib.sha256(sign_str.encode('utf-8')).hexdigest()

    def _translate(self, src_list: List[str]) -> List[str]:
        """Translate every string in *src_list*.

        Returns:
            One result per input, in order. Blank inputs and responses
            without a usable ``translation`` entry produce ``''``.
        """
        if not src_list:
            return []

        url = "https://openapi.youdao.com/api"
        # Loop-invariant values are resolved once.
        source_code = self.lang_map[self.lang_source]
        target_code = self.lang_map[self.lang_target]
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded'
        }

        results = []
        for query in src_list:
            if not query or not query.strip():
                # Nothing to translate; skip the request.
                results.append('')
                continue

            salt = str(uuid.uuid4())
            curtime = str(int(time.time()))
            payload = {
                'q': query,
                'from': source_code,
                'to': target_code,
                'appKey': self.api_key,
                'salt': salt,
                'sign': self.generate_sign(query, salt, curtime),
                'signType': 'v3',
                'curtime': curtime,
            }

            response = requests.post(
                url, data=payload, headers=headers, timeout=self._REQUEST_TIMEOUT
            )
            response_data = response.json()

            translation = None
            if isinstance(response_data, dict):
                translation = response_data.get('translation')
            # An absent or empty translation list yields '' rather than an
            # IndexError.
            results.append(translation[0] if translation else '')

        return results
+47 −19
Original line number Diff line number Diff line
@@ -6,6 +6,10 @@ from utils.logger import logger as LOGGER
from .stylewidgets import ConfigComboBox, NoBorderPushBtn, CustomComboBox
from utils.shared import CONFIG_FONTSIZE_CONTENT, CONFIG_COMBOBOX_MIDEAN, CONFIG_COMBOBOX_LONG, CONFIG_COMBOBOX_SHORT, CONFIG_COMBOBOX_HEIGHT
from utils.config import pcfg
import logging

# Configure the logger
logging.basicConfig(level=logging.DEBUG)

from qtpy.QtWidgets import QPlainTextEdit, QHBoxLayout, QVBoxLayout, QWidget, QLabel, QComboBox, QCheckBox, QLineEdit, QGridLayout, QPushButton
from qtpy.QtCore import Qt, Signal
@@ -160,6 +164,7 @@ class ParamWidget(QWidget):
            require_label = True
            is_str = isinstance(params[param_key], str)
            is_digital = isinstance(params[param_key], float) or isinstance(params[param_key], int)
            param_widget = None  # Инициализация переменной

            if isinstance(params[param_key], bool):
                param_widget = ParamCheckBox(param_key)
@@ -179,7 +184,10 @@ class ParamWidget(QWidget):
                param_dict = params[param_key]
                display_param_name = get_param_display_name(param_key, param_dict)
                value = params[param_key]['value']
                if param_dict['type'] == 'selector':
                param_widget = None  # Ensure initialization
                param_type = param_dict['type'] if 'type' in param_dict else 'line_editor'

                if param_type == 'selector':
                    if 'url' in param_key:
                        size = CONFIG_COMBOBOX_MIDEAN
                    else:
@@ -187,8 +195,6 @@ class ParamWidget(QWidget):

                    param_widget = ParamComboBox(param_key, param_dict['options'], size=size, scrollWidget=scrollWidget)

                    # if cuda is not available, disable combobox 'cuda' item
                    # https://stackoverflow.com/questions/38915001/disable-specific-items-in-qcombobox
                    if param_key == 'device' and DEFAULT_DEVICE == 'cpu':
                        param_dict['value'] = 'cpu'
                        for ii, device in enumerate(param_dict['options']):
@@ -197,22 +203,28 @@ class ParamWidget(QWidget):
                                item = model.item(ii, 0)
                                item.setEnabled(False)
                    param_widget.setCurrentText(str(value))
                    param_widget.paramwidget_edited.connect(self.on_paramwidget_edited)
                elif param_dict['type'] == 'editor':

                elif param_type == 'editor':
                    param_widget = ParamEditor(param_key)
                    param_widget.setText(value)
                    param_widget.paramwidget_edited.connect(self.on_paramwidget_edited)
                elif param_dict['type'] == 'checkbox':

                elif param_type == 'checkbox':
                    param_widget = ParamCheckBox(param_key)
                    if isinstance(value, str):
                        value = value.lower().strip() == 'true'
                        params[param_key]['value'] = value
                    param_widget.setChecked(value)
                    param_widget.paramwidget_edited.connect(self.on_paramwidget_edited)
                elif param_dict['type'] == 'pushbtn':

                elif param_type == 'pushbtn':
                    param_widget = ParamPushButton(param_key, param_dict)
                    param_widget.paramwidget_edited.connect(self.on_paramwidget_edited)
                    require_label = False

                elif param_type == 'line_editor':
                    param_widget = ParamLineEditor(param_key, force_digital=is_digital)
                    param_widget.setText(str(value))

                if param_widget is not None:
                    param_widget.paramwidget_edited.connect(self.on_paramwidget_edited)
                    if 'description' in param_dict:
                        param_widget.setToolTip(param_dict['description'])

@@ -221,12 +233,28 @@ class ParamWidget(QWidget):
                param_label = ParamNameLabel(display_param_name)
                param_layout.addWidget(param_label, ii, 0)
                widget_idx = 1
            if param_widget:
                param_layout.addWidget(param_widget, ii, widget_idx)
            else:
                raise ValueError(f"Failed to initialize widget for key: {param_key}")

    def on_paramwidget_edited(self, param_key, param_content):
        """Re-emit an edit through ``paramwidget_edited``.

        The edited content is wrapped as ``{'content': param_content}`` and
        emitted together with *param_key*.
        """
        self.paramwidget_edited.emit(param_key, {'content': param_content})

class ModuleParseWidgets(QWidget):
    """Container widget that shows a ``ParamWidget`` for a module's params."""

    def addModulesParamWidgets(self, ocr_instance):
        """Store ``ocr_instance.get_params()`` and rebuild the displayed widget."""
        self.params = ocr_instance.get_params()
        self.on_module_changed()

    def on_module_changed(self):
        """Rebuild the parameter widget for the currently stored params."""
        self.updateModuleParamWidget()

    def updateModuleParamWidget(self):
        """Show a fresh ``ParamWidget`` built from ``self.params``.

        A widget accepts only one layout, so the layout is created on the
        first call and reused afterwards; the previous parameter widget is
        removed before the new one is added.
        """
        layout = self.layout()
        if layout is None:
            layout = QVBoxLayout()
            self.setLayout(layout)
        else:
            # Drop whatever was shown before so stale widgets do not pile up.
            while layout.count():
                stale = layout.takeAt(0).widget()
                if stale is not None:
                    stale.deleteLater()
        layout.addWidget(ParamWidget(self.params, scrollWidget=self))

class ModuleConfigParseWidget(QWidget):
    module_changed = Signal(str)