Commit fc247d8d authored by dmMaze's avatar dmMaze
Browse files

ocr & textdetector refactoring

parent 1fc67df5
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
__pycache__
local_*
inpainted
mask
result
+18 −12
Original line number Diff line number Diff line
@@ -98,19 +98,25 @@ def commit_hash():
    return stored_commit_hash


TRANSLATOR_DIR = 'modules/translators'
# The dot is escaped and the pattern is anchored at the end so that only real
# ``trans_<name>.py`` sources match; ``trans_x.pyc`` or ``trans_x.py.bak``
# used to match and then failed on import.
TRANSLATOR_PATTERN = re.compile(r'trans_(.*?)\.py$')


def load_translators(translators=None):
    """Import translator plugin modules from ``modules.translators``.

    Args:
        translators: optional iterable of file names (``trans_<name>.py``).
            When ``None``, the contents of ``TRANSLATOR_DIR`` are listed
            instead.

    Names that do not match ``TRANSLATOR_PATTERN`` are ignored. Any error
    raised by ``importlib.import_module`` propagates to the caller.
    """
    if translators is None:
        translators = os.listdir(TRANSLATOR_DIR)

    for translator in translators:
        if TRANSLATOR_PATTERN.match(translator) is None:
            continue
        # Strip only the trailing extension; ``str.replace`` would also eat a
        # '.py' occurring in the middle of the file name.
        module_name = os.path.splitext(translator)[0]
        importlib.import_module('modules.translators.' + module_name)


def load_modules():
    """Import every plugin module found in the known plugin directories.

    For each ``(directory, file-name pattern)`` pair in the table below, all
    matching ``*.py`` files are imported as submodules of the corresponding
    package. Translators are part of the table, so no separate translator
    loading pass is needed.

    Errors from ``os.listdir`` or ``importlib.import_module`` propagate.
    """

    def _load_module(module_dir: str, module_pattern: str):
        """Import the files in ``module_dir`` whose full name matches ``module_pattern``."""
        pattern = re.compile(module_pattern)
        module_path = module_dir.replace('/', '.')
        if not module_path.endswith('.'):
            module_path += '.'
        for file_name in os.listdir(module_dir):
            # fullmatch: reject 'x.pyc', 'x.py.bak' and similar near-misses.
            if pattern.fullmatch(file_name) is None:
                continue
            # Drop only the trailing extension, never a '.py' inside the name.
            importlib.import_module(module_path + os.path.splitext(file_name)[0])

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?)\.py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?)\.py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?)\.py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?)\.py'},
    ]:
        _load_module(**kwargs)

# Module-level placeholders, both initialised to None at import time.
# NOTE(review): presumably assigned later during application start-up —
# confirm against the code that sets them.
BT = None
APP = None
+3 −3
Original line number Diff line number Diff line
from .ocr import OCR, OCRBase, OCRMIT32px, OCRMIT48pxCTC, MangaOCR
from .textdetector import TEXTDETECTORS, TextDetectorBase, ComicTextDetector
from .ocr import OCR, OCRBase
from .textdetector import TEXTDETECTORS, TextDetectorBase
from .translators import TRANSLATORS, BaseTranslator
from .inpaint import INPAINTERS, InpainterBase, PatchmatchInpainter, AOTInpainter, LamaInpainterMPE
from .inpaint import INPAINTERS, InpainterBase
from .base import DEFAULT_DEVICE, GPUINTENSIVE_SET

def GET_VALID_TEXTDETECTORS():
    """Return the names currently present in the text-detector registry as a list."""
    return list(TEXTDETECTORS.module_dict.keys())
+3 −323
Original line number Diff line number Diff line
from typing import Tuple, List, Dict, Union, Callable
from ordered_set import OrderedSet
from typing import List
import numpy as np
import logging
from collections import OrderedDict

from utils.textblock import TextBlock

from utils.registry import Registry
# Registry of the available OCR engines. Classes are added to it through the
# ``register_OCR`` decorator, e.g. ``@register_OCR('mit32px')``.
OCR = Registry('OCR')
register_OCR = OCR.register_module

from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER

class OCRBase(BaseModule):
    """Common base class of the OCR engines kept in the ``OCR`` registry.

    Subclasses implement :meth:`_ocr_blk_list` and :meth:`ocr_img`; callers
    use :meth:`run_ocr`.
    """

    # Hook tables live on the class, so they are shared by all OCR classes
    # and instances. ``run_ocr`` invokes the post-process hooks; the
    # pre-process table is not used inside this class.
    _postprocess_hooks = OrderedDict()
    _preprocess_hooks = OrderedDict()

    def __init__(self, **params) -> None:
        super().__init__(**params)
        # The name is the key this class was registered under, or '' if the
        # class is not in the registry.
        self.name = next(
            (key for key in OCR.module_dict if OCR.module_dict[key] == self.__class__),
            '',
        )

    def run_ocr(self, img: np.ndarray, blk_list: List[TextBlock] = None) -> Union[List[TextBlock], str]:
        """Run OCR on a whole image or on the given text blocks.

        Args:
            img: image handed to the engine.
            blk_list: ``None`` to OCR the whole image, a single ``TextBlock``,
                or a list of them.

        Returns:
            The result of :meth:`ocr_img` when ``blk_list`` is ``None``;
            otherwise the list of blocks after :meth:`_ocr_blk_list` and all
            post-process hooks have run.
        """
        if not self.all_model_loaded():
            self.load_model()

        # Whole-image mode.
        if blk_list is None:
            return self.ocr_img(img)

        if isinstance(blk_list, TextBlock):
            blk_list = [blk_list]

        # Every engine except 'none_ocr' starts from empty block text.
        reset_text = self.name != 'none_ocr'
        for blk in blk_list:
            if reset_text:
                blk.text = []

        self._ocr_blk_list(img, blk_list)
        for hook in self._postprocess_hooks.values():
            hook(textblocks=blk_list, img=img, ocr_module=self)
        return blk_list

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]) -> None:
        """Fill in the text of every block in ``blk_list``; must be overridden."""
        raise NotImplementedError

    def ocr_img(self, img: np.ndarray) -> str:
        """Recognise the text of the whole image; must be overridden."""
        raise NotImplementedError

from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEFAULT_DEVICE, TextBlock, OCR

from .model_32px import OCR32pxModel
@register_OCR('mit32px')
@@ -110,54 +53,6 @@ class OCRMIT32px(OCRBase):
        self.model.max_chunk_size = chunk_size



@register_OCR('manga_ocr')
class MangaOCR(OCRBase):
    """OCR engine registered as ``'manga_ocr'``, wrapping ``manga_ocr.MangaOcr``."""

    params = {
        'device': DEVICE_SELECTOR()
    }
    device = DEFAULT_DEVICE

    download_file_list = [{
        'url': 'https://huggingface.co/kha-white/manga-ocr-base/resolve/main/',
        'files': ['pytorch_model.bin', 'config.json', 'preprocessor_config.json', 'README.md', 'special_tokens_map.json', 'tokenizer_config.json', 'vocab.txt'],
        'sha256_pre_calculated': ['c63e0bb5b3ff798c5991de18a8e0956c7ee6d1563aca6729029815eda6f5c2eb', None, None, None, None, None, None],
        'save_dir': 'data/models/manga-ocr-base',
        'concatenate_url_filename': 1,
    }]
    _load_model_keys = {'model'}

    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.device = self.params['device']['value']
        # ``MangaOcr`` instance; stays None until ``_load_model`` runs.
        self.model = None

    def _load_model(self):
        """Create the underlying model on the currently selected device."""
        # Imported here so the dependency is only loaded when the model is.
        from .manga_ocr import MangaOcr
        self.model = MangaOcr(device=self.device)

    def ocr_img(self, img: np.ndarray) -> str:
        """Return the model output for the whole image."""
        return self.model(img)

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Recognise each block's crop; blocks with an unusable box get ``['']``."""
        im_h, im_w = img.shape[:2]
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
            # A box touching the image border is still a valid slice, so the
            # bounds are inclusive (x1/y1 may be 0, x2/y2 may equal the size).
            if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                blk.text = self.model(img[y1:y2, x1:x2])
            else:
                logging.warning('invalid textbbox to target img')
                blk.text = ['']

    def updateParam(self, param_key: str, param_content):
        """Store the parameter and, if the device changed, switch to it."""
        super().updateParam(param_key, param_content)
        device = self.params['device']['value']
        if self.device != device:
            # Remember the new device so a later ``_load_model`` uses it and
            # the model is not moved again on every subsequent update.
            self.device = device
            # The model may not have been loaded yet.
            if self.model is not None:
                self.model.to(device)



from .mit48px_ctc import OCR48pxCTC
@register_OCR('mit48px_ctc')
class OCRMIT48pxCTC(OCRBase):
@@ -208,7 +103,6 @@ class OCRMIT48pxCTC(OCRBase):
        self.model.max_chunk_size = chunk_size



from .mit48px import Model48pxOCR
OCR48PXMODEL_PATH = r'data/models/ocr_ar_48px.ckpt'
@register_OCR('mit48px')
@@ -243,217 +137,3 @@ class OCRMIT48px(OCRBase):
        device = self.params['device']['value']
        if self.device != device:
            self.model.to(device)
 No newline at end of file

from .stariver_ocr import StariverOCR
@register_OCR('stariver_ocr')
class OCRStariver(OCRBase):
    """OCR engine registered as ``'stariver_ocr'``, backed by a ``StariverOCR`` client."""

    params = {
        'token': 'Replace with your token',
        "refine":{
            'type': 'checkbox',
            'value': True
        },
        "filtrate":{
            'type': 'checkbox',
            'value': True
        },
        "disable_skip_area":{
            'type': 'checkbox',
            'value': True
        },
        "detect_scale": "3",
        "merge_threshold": "2",
        "force_expand":{
            'type': 'checkbox',
            'value': False,
            # Description (translated): whether to force-expand the image
            # pixels; this slows recognition down.
            'description': '是否强制扩展图片像素,会导致识别速度下降'
        },
        'description': '星河云(团子翻译器) OCR API'
    }

    @property
    def token(self):
        """API token string from ``params``."""
        return self.params['token']

    @property
    def expand_ratio(self):
        """``params['expand_ratio']`` as a float."""
        # NOTE(review): the class-level ``params`` above has no
        # 'expand_ratio' entry, so this raises KeyError unless the key is
        # supplied elsewhere — confirm whether the property is still needed.
        return float(self.params['expand_ratio'])

    @property
    def refine(self):
        """Value of the 'refine' checkbox."""
        return self.params['refine']['value']

    @property
    def filtrate(self):
        """Value of the 'filtrate' checkbox."""
        # The original omitted ``return``, so the client always got None.
        return self.params['filtrate']['value']

    @property
    def disable_skip_area(self):
        """Value of the 'disable_skip_area' checkbox."""
        return self.params['disable_skip_area']['value']

    @property
    def detect_scale(self):
        """``params['detect_scale']`` as an int."""
        return int(self.params['detect_scale'])

    @property
    def merge_threshold(self):
        """``params['merge_threshold']`` as a float."""
        return float(self.params['merge_threshold'])

    @property
    def force_expand(self):
        """Value of the 'force_expand' checkbox."""
        return self.params['force_expand']['value']

    def __init__(self, **params) -> None:
        super().__init__(**params)
        self.client = StariverOCR(
            self.token,
            refine=self.refine,
            filtrate=self.filtrate,
            disable_skip_area=self.disable_skip_area,
            detect_scale=self.detect_scale,
            merge_threshold=self.merge_threshold,
            force_expand=self.force_expand,
        )

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Send each block's crop to the client; unusable boxes get ``['']``."""
        im_h, im_w = img.shape[:2]
        for blk in blk_list:
            x1, y1, x2, y2 = blk.xyxy
            # Inclusive bounds: a box touching the image border is still a
            # valid slice.
            if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                blk.text = self.client.ocr(img[y1:y2, x1:x2])
            else:
                logging.warning('invalid textbbox to target img')
                blk.text = ['']

    def ocr_img(self, img: np.ndarray) -> str:
        """Return the client's result for the whole image."""
        self.logger.debug(f'ocr_img: {img.shape}')
        return self.client.ocr(img)

    def updateParam(self, param_key: str, param_content):
        """Store the parameter and refresh the client's token."""
        super().updateParam(param_key, param_content)
        # Only the token is pushed to the existing client here.
        self.client.token = self.params['token']

@register_OCR('none_ocr')
class OCRNone(OCRBase):
    """Placeholder engine registered as ``'none_ocr'``; it performs no recognition."""

    params = {
        'NOTICE': 'Not a OCR, just return original text.',
        'description': 'Not a OCR, just return original text.'
    }

    def __init__(self, **params) -> None:
        super().__init__(**params)

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Do nothing: the blocks are left exactly as they were passed in."""
        return None

    def ocr_img(self, img: np.ndarray) -> str:
        """Always return the empty string."""
        return ''
    
import platform
# Compare the macOS version numerically. As strings, '10.9' >= '10.15' is
# true, which wrongly enabled this block. On non-mac systems mac_ver()[0] is
# '', which yields an empty tuple and therefore False.
_mac_version = tuple(int(part) for part in platform.mac_ver()[0].split('.') if part.isdigit())
if _mac_version >= (10, 15):
    from .macos_ocr import get_supported_languages

    macos_ocr_supported_languages = get_supported_languages()

    # NOTE(review): the options below use element [0] of this result while
    # the guard checks the length of the result itself — confirm that this is
    # the intended emptiness test.
    if len(macos_ocr_supported_languages) > 0:
        @register_OCR('macos_ocr')
        class OCRApple(OCRBase):
            """OCR engine registered as ``'macos_ocr'``, wrapping ``macos_ocr.AppleOCR``."""

            params = {
                'language': {
                    'type':'selector',
                    'options': list(get_supported_languages()[0]),
                    'value': 'en-US',
                },
                # While this does appear 
                # it doesn't update the languages available
                # different recog level, different available langs
                # 'recognition_level': {
                #     'type': 'selector',
                #     'options': [
                #         'accurate',
                #         'fast',
                #     ],
                #     'value': 'accurate',
                # },
                'confidence_level': '0.1',
            }
            language = 'en-US'
            recognition = 'accurate'
            confidence = '0.1'

            def __init__(self, **params) -> None:
                super().__init__(**params)
                from .macos_ocr import AppleOCR
                self.model = AppleOCR(lang=[self.language])

            def ocr_img(self, img: np.ndarray) -> str:
                """Return the model output for the whole image."""
                return self.model(img)

            def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
                """Recognise each block's crop; unusable boxes get ``['']``."""
                im_h, im_w = img.shape[:2]
                for blk in blk_list:
                    x1, y1, x2, y2 = blk.xyxy
                    # Inclusive bounds: a box touching the image border is
                    # still a valid slice.
                    if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                        blk.text = self.model(img[y1:y2, x1:x2])
                    else:
                        logging.warning('invalid textbbox to target img')
                        blk.text = ['']

            def updateParam(self, param_key: str, param_content):
                """Store the parameter and push language/confidence to the model."""
                super().updateParam(param_key, param_content)
                self.language = self.params['language']['value']
                self.model.lang = [self.language]

                # self.recognition = self.params['recognition_level']['value']
                # self.model.recog_level = self.recognition
                # self.params['language']['options'] = list(get_supported_languages(self.recognition)[0])

                self.confidence = self.params['confidence_level']
                self.model.min_confidence = self.confidence
    else:
        LOGGER.warning(f'No supported language packs found for MacOS, MacOS OCR will be unavailable.')
                

# Compare the Windows version numerically. As strings, '6.1.7601' (Windows 7)
# is >= '10.0.10240.0', which wrongly enabled this block on old systems.
_win_version = tuple(int(part) for part in platform.version().split('.') if part.isdigit())
if platform.system() == 'Windows' and _win_version >= (10, 0, 10240):
    from .windows_ocr import winocr_available_recognizer_languages

    if len(winocr_available_recognizer_languages) > 0:

        languages_display_name = [lang.display_name for lang in winocr_available_recognizer_languages]
        languages_tag = [lang.language_tag for lang in winocr_available_recognizer_languages]
        @register_OCR('windows_ocr')
        class OCRWindows(OCRBase):
            """OCR engine registered as ``'windows_ocr'``, wrapping ``windows_ocr.WindowsOCR``."""

            params = {
                'language': {
                    'type':'selector',
                    'options': languages_display_name,
                    'value': languages_display_name[0],
                }
            }
            language = languages_display_name[0]

            def __init__(self, **params) -> None:
                super().__init__(**params)
                from .windows_ocr import WindowsOCR
                self.engine = WindowsOCR()
                self.engine.lang = self.get_engine_lang()

            def get_engine_lang(self) -> str:
                """Map the selected display name to its language tag.

                Raises ``ValueError`` (from ``list.index``) if the selected
                value is not one of the known display names.
                """
                language = self.params['language']['value']
                tag_name = languages_tag[languages_display_name.index(language)]
                return tag_name

            def ocr_img(self, img: np.ndarray) -> str:
                """Return the engine output for the whole image."""
                # The original dropped the result and implicitly returned None.
                return self.engine(img)

            def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]) -> None:
                """Recognise each block's crop; unusable boxes get ``['']``."""
                im_h, im_w = img.shape[:2]
                for blk in blk_list:
                    x1, y1, x2, y2 = blk.xyxy
                    # Inclusive bounds: a box touching the image border is
                    # still a valid slice.
                    if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                        blk.text = self.engine(img[y1:y2, x1:x2])
                    else:
                        logging.warning('invalid textbbox to target img')
                        blk.text = ['']

            def updateParam(self, param_key: str, param_content):
                """Store the parameter and re-apply the selected language."""
                super().updateParam(param_key, param_content)
                self.engine.lang = self.get_engine_lang()

    else:
        LOGGER.warning(f'No supported language packs found for windows, Windows OCR will be unavailable.')
 No newline at end of file

modules/ocr/base.py

0 → 100644
+58 −0
Original line number Diff line number Diff line
from typing import Tuple, List, Dict, Union, Callable
import numpy as np
from collections import OrderedDict

from utils.textblock import TextBlock

from utils.registry import Registry
# Registry of the available OCR engines. Classes are added to it through the
# ``register_OCR`` decorator, e.g. ``@register_OCR('mit32px')``.
OCR = Registry('OCR')
register_OCR = OCR.register_module

from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER

class OCRBase(BaseModule):
    """Common base class of the OCR engines kept in the ``OCR`` registry.

    Subclasses implement :meth:`_ocr_blk_list` and :meth:`ocr_img`; callers
    use :meth:`run_ocr`.
    """

    # Hook tables live on the class, so they are shared by all OCR classes
    # and instances. ``run_ocr`` invokes the post-process hooks; the
    # pre-process table is not used inside this class.
    _postprocess_hooks = OrderedDict()
    _preprocess_hooks = OrderedDict()

    def __init__(self, **params) -> None:
        super().__init__(**params)
        # The name is the key this class was registered under, or '' if the
        # class is not in the registry.
        self.name = next(
            (key for key in OCR.module_dict if OCR.module_dict[key] == self.__class__),
            '',
        )

    def run_ocr(self, img: np.ndarray, blk_list: List[TextBlock] = None) -> Union[List[TextBlock], str]:
        """Run OCR on a whole image or on the given text blocks.

        Args:
            img: image handed to the engine.
            blk_list: ``None`` to OCR the whole image, a single ``TextBlock``,
                or a list of them.

        Returns:
            The result of :meth:`ocr_img` when ``blk_list`` is ``None``;
            otherwise the list of blocks after :meth:`_ocr_blk_list` and all
            post-process hooks have run.
        """
        if not self.all_model_loaded():
            self.load_model()

        # Whole-image mode.
        if blk_list is None:
            return self.ocr_img(img)

        if isinstance(blk_list, TextBlock):
            blk_list = [blk_list]

        # Every engine except 'none_ocr' starts from empty block text.
        reset_text = self.name != 'none_ocr'
        for blk in blk_list:
            if reset_text:
                blk.text = []

        self._ocr_blk_list(img, blk_list)
        for hook in self._postprocess_hooks.values():
            hook(textblocks=blk_list, img=img, ocr_module=self)
        return blk_list

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]) -> None:
        """Fill in the text of every block in ``blk_list``; must be overridden."""
        raise NotImplementedError

    def ocr_img(self, img: np.ndarray) -> str:
        """Recognise the text of the whole image; must be overridden."""
        raise NotImplementedError
Loading