Commit 0fb360c3 authored by dmMaze's avatar dmMaze
Browse files

add run module template & fix textlabel crash, close #972

parent 28e6eb5c
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -215,9 +215,9 @@ def main():

    # import msl.loadlib (required by translators/trans_eztrans) before init QApplication
    # yield QWindowsContext: OleInitialize() failed on py3.10, 
    from modules.base import load_modules
    from modules.base import init_module_registries
    from modules.prepare_local_files import prepare_local_files_forall
    load_modules()
    init_module_registries()
    prepare_local_files_forall()

    if not args.headless:
+10 −1
Original line number Diff line number Diff line
@@ -2,11 +2,20 @@ from .ocr import OCR, OCRBase
from .textdetector import TEXTDETECTORS, TextDetectorBase
from .translators import TRANSLATORS, BaseTranslator
from .inpaint import INPAINTERS, InpainterBase
from .base import DEFAULT_DEVICE, GPUINTENSIVE_SET
from .base import DEFAULT_DEVICE, GPUINTENSIVE_SET, LOGGER, merge_config_module_params, \
    init_module_registries, init_textdetector_registries, init_inpainter_registries, init_ocr_registries, init_translator_registries

GET_VALID_TEXTDETECTORS = lambda : list(TEXTDETECTORS.module_dict.keys())
GET_VALID_TRANSLATORS = lambda : list(TRANSLATORS.module_dict.keys())
GET_VALID_INPAINTERS = lambda : list(INPAINTERS.module_dict.keys())
GET_VALID_OCR = lambda : list(OCR.module_dict.keys())


MODULETYPE_TO_REGISTRIES = {
    'textdetector': TEXTDETECTORS,
    'ocr': OCR,
    'inpainter': INPAINTERS,
    'translator': TRANSLATORS
}

# TODO: use manga-image-translator as backend...
 No newline at end of file
+107 −8
Original line number Diff line number Diff line
@@ -33,6 +33,79 @@ def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callabl
            hooks_registered[hk] = callback
            nhooks += 1


def patch_module_params(cfg_param, module_params, module_name: str = ''):
    # cfg_param = config_params[module_key]
    cfg_key_set = set(cfg_param.keys())
    module_key_set = set(module_params.keys())
    for ck in cfg_key_set:
        if ck not in module_key_set:
            LOGGER.warning(f'Found invalid {module_name} config: {ck}')
            cfg_param.pop(ck)

    for mk in module_key_set:
        if mk not in cfg_key_set:
            if not mk.startswith('__') and mk != 'description':
                LOGGER.info(f'Found new {module_name} config: {mk}')
            cfg_param[mk] = module_params[mk]
        else:
            mparam = module_params[mk]
            cparam = cfg_param[mk]
            if isinstance(mparam, dict):
                tgt_type = type(mparam['value'])
                if isinstance(cparam, dict):
                    if 'value' in cparam:
                        v = cparam['value']
                    elif isinstance(mparam['value'], dict):
                        for k in mparam['value']:
                            if k in cparam:
                                mparam['value'][k] = cparam[k]
                        v = mparam['value']
                    else:
                        v = mparam['value']
                else:
                    v = cparam
                valid = True
                if tgt_type != type(v):
                    try:
                        v = tgt_type(v)
                    except:
                        valid = False
                        LOGGER.warning(f'Invalid param value {v} for defined dtype: {tgt_type}, it will be set to default value: {mparam}')
                if valid:
                    mparam['value'] = v
                cfg_param[mk] = mparam
            else:
                if type(cparam) != type(mparam):
                    if not isinstance(mparam, dict) and isinstance(cparam, dict):
                        cparam = cparam['value']
                    try:
                        cfg_param[mk] = type(mparam)(cparam)
                    except ValueError:
                        LOGGER.warning(f'Invalid param value {cparam} for defined dtype: {type(mparam)}, it will be set to default value: {mparam}')
                        cfg_param[mk] = mparam
    
    cfg_key_list = list(cfg_param.keys())
    module_key_list = list(module_params.keys())
    if cfg_key_list != module_key_list:
        new_params = {key: cfg_param[key] for key in module_key_list}
        cfg_param.clear()
        cfg_param.update(new_params)
        module_key_set = set(module_params.keys())
    cfg_param['__param_patched'] = True
    return cfg_param


def merge_config_module_params(config_params: Dict, module_keys: List, get_module: Callable) -> Dict:
    for module_key in module_keys:
        module_params = get_module(module_key).params
        if module_key not in config_params or config_params[module_key] is None:
            config_params[module_key] = module_params
        else:
            patch_module_params(config_params[module_key], module_params, module_key)
    return config_params


class BaseModule:

    params: Dict = None
@@ -47,6 +120,8 @@ class BaseModule:
    _load_model_keys: set = None

    def __init__(self, **params) -> None:
        if self.params is not None and '__param_patched' not in params:
            params = patch_module_params(params, self.params, self)
        if params:
            if self.params is None:
                self.params = params
@@ -228,7 +303,14 @@ TORCH_DTYPE_MAP = {
    'bf16': torch.bfloat16,
}

def load_modules():
MODULE_SCRIPTS = {
    'translator': {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
    'textdetector': {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
    'inpainter': {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
    'ocr': {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
}
    
def init_module_registries(target_modules=None):
    def _load_module(module_dir: str, module_pattern: str):
        modules = os.listdir(module_dir)
        pattern = re.compile(module_pattern)
@@ -243,10 +325,27 @@ def load_modules():
                except Exception as e:
                    LOGGER.warning(f'Failed to import {module}: {e}')

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
    ]:
        _load_module(**kwargs)
    if target_modules is None:
        target_modules = MODULE_SCRIPTS
    if isinstance(target_modules, str):
        target_modules = [target_modules]

    for k in target_modules:
        _load_module(**MODULE_SCRIPTS[k])


def init_textdetector_registries():
    init_module_registries('textdetector')


def init_inpainter_registries():
    init_module_registries('inpainter')


def init_ocr_registries():
    init_module_registries('ocr')


def init_translator_registries():
    init_module_registries('translator')

scripts/run_module.py

0 → 100644
+66 −0
Original line number Diff line number Diff line
import click

import sys, os
import os.path as osp
sys.path.append(osp.dirname(osp.dirname(__file__)))

from tqdm import tqdm

from utils.config import load_config

from utils.shared import PROGRAM_PATH
from utils.textblock import visualize_textblocks
from utils.proj_imgtrans import ProjImgTrans
from utils.config import pcfg
from utils.io_utils import imread, imwrite
from modules import MODULETYPE_TO_REGISTRIES, init_translator_registries, init_inpainter_registries, init_ocr_registries, init_textdetector_registries


os.chdir(PROGRAM_PATH)


@click.group()
def cli():
    """text detector testing scripts.
    """



def init_module(module_type: str, module_name: str):
    assert module_type in MODULETYPE_TO_REGISTRIES
    module_cls = MODULETYPE_TO_REGISTRIES[module_type].get(module_name)
    module_cls_params = getattr(pcfg.module, module_type + '_params')
    module_params = module_cls_params.get(module_name, {})
    return module_cls(**module_params)


@cli.command('run_detector')
@click.option('--proj_dir')
@click.option('--detector', default=None)
@click.option('--config', default='config/config.json')
@click.option('--save_dir', default='tmp/test_ctd')
def run_detector(proj_dir, detector, config, save_dir):

    init_textdetector_registries()
    load_config(config)
    if detector is None:
        detector = pcfg.module.textdetector

    detector = init_module('textdetector', detector)
    print('detector params:', detector.params)

    proj = ProjImgTrans(proj_dir)
    for page_name in tqdm(proj.pages):
        blk_list = proj.pages[page_name]
        proj.set_current_img(page_name)
        mask, blk_list = detector.detect(proj.img_array, blk_list)
        blk_list = blk_list[:1]
        print(blk_list[0].get_text())
        vis = visualize_textblocks(proj.img_array, blk_list)
        imwrite(osp.join(save_dir, proj.current_img), vis, ext='.jpg')
        pass
    


if __name__ == '__main__':
    cli()
+2 −2
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ from modules.translators import MissingTranslatorParams
from modules.base import BaseModule, soft_empty_cache
from modules import INPAINTERS, TRANSLATORS, TEXTDETECTORS, OCR, \
    GET_VALID_TRANSLATORS, GET_VALID_TEXTDETECTORS, GET_VALID_INPAINTERS, GET_VALID_OCR, \
    BaseTranslator, InpainterBase, TextDetectorBase, OCRBase
    BaseTranslator, InpainterBase, TextDetectorBase, OCRBase, merge_config_module_params
import modules
modules.translators.SYSTEM_LANG = QLocale.system().name()
from utils.textblock import TextBlock, sort_regions
@@ -24,7 +24,7 @@ from utils.message import create_error_dialog, create_info_dialog
from .custom_widget import ImgtransProgressMessageBox, ParamComboBox
from .configpanel import ConfigPanel
from utils.proj_imgtrans import ProjImgTrans
from utils.config import pcfg, merge_config_module_params
from utils.config import pcfg
cfg_module = pcfg.module


Loading