Unverified Commit ef14a78f authored by dmMaze's avatar dmMaze Committed by GitHub
Browse files

Merge pull request #551 from bropines/lens_ocr

Update ocr_google_lens.py
parents 27c938a4 ec6a67ae
Loading
Loading
Loading
Loading
+100 −28
Original line number Diff line number Diff line
@@ -76,6 +76,60 @@ class LensAPI:
    def __init__(self):
        """Instantiate ``Lens`` and keep it on ``self.lens`` for later scans."""
        self.lens = Lens()

    @staticmethod
    def extract_text_and_coordinates(data):
        text_with_coords = []
        if isinstance(data, list):
            for item in data:
                if isinstance(item, list):
                    for sub_item in item:
                        if isinstance(sub_item, list) and len(sub_item) > 1 and isinstance(sub_item[0], str):
                            word = sub_item[0]
                            coords = sub_item[1]
                            if isinstance(coords, list) and all(isinstance(coord, (int, float)) for coord in coords):
                                text_with_coords.append({"text": word, "coordinates": coords})
                        else:
                            text_with_coords.extend(LensAPI.extract_text_and_coordinates(sub_item))
                else:
                    text_with_coords.extend(LensAPI.extract_text_and_coordinates(item))
        elif isinstance(data, dict):
            for value in data.values():
                text_with_coords.extend(LensAPI.extract_text_and_coordinates(value))
        return text_with_coords

    @staticmethod
    def stitch_text_smart(text_with_coords):
        # Преобразование Swap X and Y
        transformed_coords = [{'text': item['text'], 'coordinates': [item['coordinates'][1], item['coordinates'][0]]} for item in text_with_coords]
        sorted_elements = sorted(transformed_coords, key=lambda x: (round(x['coordinates'][1], 2), x['coordinates'][0]))

        stitched_text = []
        current_y = None
        current_line = []
        for element in sorted_elements:
            if current_y is None or abs(element['coordinates'][1] - current_y) > 0.05:
                if current_line:
                    stitched_text.append(" ".join(current_line))
                    current_line = []
                current_y = element['coordinates'][1]
            if element['text'] in [',', '.', '!', '?', ';', ':'] and current_line:
                current_line[-1] += element['text']
            else:
                current_line.append(element['text'])
        if current_line:
            stitched_text.append(" ".join(current_line))
        return "\n".join(stitched_text).strip()

    @staticmethod
    def stitch_text_sequential(text_with_coords):
        # Используем порядок элементов в исходном списке
        stitched_text = " ".join([element['text'] for element in text_with_coords])
        
        # Удаляем лишние пробелы вокруг знаков препинания
        stitched_text = re.sub(r'\s+([,?.!])', r'\1', stitched_text)
        
        return stitched_text.strip()

    @staticmethod
    def extract_full_text(data):
        try:
@@ -93,7 +147,7 @@ class LensAPI:
        except (IndexError, TypeError):
            return "Language not found in expected structure"

    def process_image(self, image_path=None, image_buffer=None):
    def process_image(self, image_path=None, image_buffer=None, response_method="Full Text"):
        if image_path:
            result = self.lens.scan_by_file(image_path)
        elif image_buffer:
@@ -101,11 +155,28 @@ class LensAPI:
        else:
            raise ValueError("Either image_path or image_buffer must be provided")

        text_with_coords = self.extract_text_and_coordinates(result['data'])

        if response_method == "Full Text":
            return {
                'full_text': self.extract_full_text(result['data']),
            'language': self.extract_language(result['data'])
                'language': self.extract_language(result['data']),
                'text_with_coordinates': text_with_coords
            }

        elif response_method == "Coordinate sequence":
            return {
                'full_text': self.stitch_text_sequential(text_with_coords),
                'language': self.extract_language(result['data']),
                'text_with_coordinates': text_with_coords
            }
        elif response_method == "Location coordinates":
            return {
                'full_text': self.stitch_text_smart(text_with_coords),
                'language': self.extract_language(result['data']),
                'text_with_coordinates': text_with_coords
            }
        else:
            raise ValueError("Invalid response method")

@register_OCR('google_lens')
class OCRLensAPI(OCRBase):
@@ -125,6 +196,16 @@ class OCRLensAPI(OCRBase):
            'value': False,
            'description': 'Convert text to lowercase except the first letter of each sentence'
        },
        'response_method': {
            'type': 'selector',
            'options': [
                'Full Text',
                'Coordinate sequence',
                'Location coordinates'
            ],
            'value': 'Full Text',
            'description': 'Choose the method for extracting text from image'
        },
        'description': 'OCR using Google Lens API'
    }
    
@@ -140,6 +221,10 @@ class OCRLensAPI(OCRBase):
    def no_uppercase(self):
        """Return the current value of the ``no_uppercase`` parameter."""
        return self.get_param_value('no_uppercase')

    @property
    def response_method(self):
        """Return the current value of the ``response_method`` parameter.

        The parameter is declared as a selector whose options are
        ``'Full Text'``, ``'Coordinate sequence'`` and
        ``'Location coordinates'``.
        """
        return self.get_param_value('response_method')

    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.api = LensAPI()
@@ -177,10 +262,9 @@ class OCRLensAPI(OCRBase):
                if self.debug_mode:
                    self.logger.info(f'Input image size: {img.shape}')
                _, buffer = cv2.imencode('.jpg', img)
                result = self.api.process_image(image_buffer=buffer.tobytes())
                result = self.api.process_image(image_buffer=buffer.tobytes(), response_method=self.response_method)
                if self.debug_mode:
                    self.logger.info(f'OCR result: {result}')
                # Check the result for the specified text
                ignore_texts = [
                    'Full text not found in expected structure',
                    'Full text not found(or Lens could not recognize it)'
@@ -191,12 +275,11 @@ class OCRLensAPI(OCRBase):
                if self.newline_handling == 'remove':
                    full_text = full_text.replace('\n', ' ')

                # Apply punctuation and spacing rules
                full_text = self._apply_punctuation_and_spacing(full_text)

                if self.no_uppercase:
                    full_text = self._apply_no_uppercase(full_text)

                full_text = self._apply_punctuation_and_spacing(full_text)

                if isinstance(full_text, list):
                    return '\n'.join(full_text)
                else:
@@ -212,32 +295,21 @@ class OCRLensAPI(OCRBase):

    def _apply_no_uppercase(self, text: str) -> str:
        def process_sentence(sentence):
            # Split the sentence into words, preserving punctuation
            words = re.findall(r'\S+|\s+', sentence)
            processed_words = []
            for i, word in enumerate(words):
                if i == 0 or words[i-1].strip() in '.!?…':
                    processed_words.append(word.capitalize())
                else:
                    processed_words.append(word.lower())
            return ''.join(processed_words)
            words = sentence.split()
            if not words:
                return ''
            processed = [words[0].capitalize()] + [word.lower() for word in words[1:]]
            return ' '.join(processed)

        # Split the text into sentences, preserving original spacing and punctuation
        sentences = re.split(r'(?<=[.!?…])', text)
        sentences = re.split(r'(?<=[.!?…])\s+', text)
        processed_sentences = [process_sentence(sentence) for sentence in sentences]
        
        return ' '.join(processed_sentences)

    def _apply_punctuation_and_spacing(self, text: str) -> str:
        # Remove extra spaces before punctuation
        text = re.sub(r'\s+([,.!?…])', r'\1', text)
        
        # Ensure single space after punctuation, except for multiple punctuation marks
        text = re.sub(r'([,.!?…])(?!\s)(?![,.!?…])', r'\1 ', text)
        
        # Remove space between multiple punctuation marks
        text = re.sub(r'([,.!?…])\s+([,.!?…])', r'\1\2', text)
        
        return text.strip()

    def _respect_delay(self):