Loading launch.py +2 −2 Original line number Diff line number Diff line Loading @@ -215,9 +215,9 @@ def main(): # import msl.loadlib (required by translators/trans_eztrans) before init QApplication # yield QWindowsContext: OleInitialize() failed on py3.10, from modules.base import load_modules from modules.base import init_module_registries from modules.prepare_local_files import prepare_local_files_forall load_modules() init_module_registries() prepare_local_files_forall() if not args.headless: Loading modules/__init__.py +10 −1 Original line number Diff line number Diff line Loading @@ -2,11 +2,20 @@ from .ocr import OCR, OCRBase from .textdetector import TEXTDETECTORS, TextDetectorBase from .translators import TRANSLATORS, BaseTranslator from .inpaint import INPAINTERS, InpainterBase from .base import DEFAULT_DEVICE, GPUINTENSIVE_SET from .base import DEFAULT_DEVICE, GPUINTENSIVE_SET, LOGGER, merge_config_module_params, \ init_module_registries, init_textdetector_registries, init_inpainter_registries, init_ocr_registries, init_translator_registries GET_VALID_TEXTDETECTORS = lambda : list(TEXTDETECTORS.module_dict.keys()) GET_VALID_TRANSLATORS = lambda : list(TRANSLATORS.module_dict.keys()) GET_VALID_INPAINTERS = lambda : list(INPAINTERS.module_dict.keys()) GET_VALID_OCR = lambda : list(OCR.module_dict.keys()) MODULETYPE_TO_REGISTRIES = { 'textdetector': TEXTDETECTORS, 'ocr': OCR, 'inpainter': INPAINTERS, 'translator': TRANSLATORS } # TODO: use manga-image-translator as backend... No newline at end of file modules/base.py +107 −8 Original line number Diff line number Diff line Loading @@ -33,6 +33,79 @@ def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callabl hooks_registered[hk] = callback nhooks += 1 def patch_module_params(cfg_param, module_params, module_name: str = ''): # cfg_param = config_params[module_key] cfg_key_set = set(cfg_param.keys()) module_key_set = set(module_params.keys()) for ck in cfg_key_set: if ck not in module_key_set: LOGGER.warning(f'Found invalid {module_name} config: {ck}') cfg_param.pop(ck) for mk in module_key_set: if mk not in cfg_key_set: if not mk.startswith('__') and mk != 'description': LOGGER.info(f'Found new {module_name} config: {mk}') cfg_param[mk] = module_params[mk] else: mparam = module_params[mk] cparam = cfg_param[mk] if isinstance(mparam, dict): tgt_type = type(mparam['value']) if isinstance(cparam, dict): if 'value' in cparam: v = cparam['value'] elif isinstance(mparam['value'], dict): for k in mparam['value']: if k in cparam: mparam['value'][k] = cparam[k] v = mparam['value'] else: v = mparam['value'] else: v = cparam valid = True if tgt_type != type(v): try: v = tgt_type(v) except: valid = False LOGGER.warning(f'Invalid param value {v} for defined dtype: {tgt_type}, it will be set to default value: {mparam}') if valid: mparam['value'] = v cfg_param[mk] = mparam else: if type(cparam) != type(mparam): if not isinstance(mparam, dict) and isinstance(cparam, dict): cparam = cparam['value'] try: cfg_param[mk] = type(mparam)(cparam) except ValueError: LOGGER.warning(f'Invalid param value {cparam} for defined dtype: {type(mparam)}, it will be set to default value: {mparam}') cfg_param[mk] = mparam cfg_key_list = list(cfg_param.keys()) module_key_list = list(module_params.keys()) if cfg_key_list != module_key_list: new_params = {key: cfg_param[key] for key in module_key_list} cfg_param.clear() cfg_param.update(new_params) module_key_set = set(module_params.keys()) cfg_param['__param_patched'] = True return cfg_param def merge_config_module_params(config_params: Dict, module_keys: List, get_module: Callable) -> Dict: for module_key in module_keys: module_params = get_module(module_key).params if module_key not in config_params or config_params[module_key] is None: config_params[module_key] = module_params else: patch_module_params(config_params[module_key], module_params, module_key) return config_params class BaseModule: params: Dict = None Loading @@ -47,6 +120,8 @@ class BaseModule: _load_model_keys: set = None def __init__(self, **params) -> None: if self.params is not None and '__param_patched' not in params: params = patch_module_params(params, self.params, self) if params: if self.params is None: self.params = params Loading Loading @@ -228,7 +303,14 @@ TORCH_DTYPE_MAP = { 'bf16': torch.bfloat16, } def load_modules(): MODULE_SCRIPTS = { 'translator': {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'}, 'textdetector': {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'}, 'inpainter': {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'}, 'ocr': {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'}, } def init_module_registries(target_modules=None): def _load_module(module_dir: str, module_pattern: str): modules = os.listdir(module_dir) pattern = re.compile(module_pattern) Loading @@ -243,10 +325,27 @@ def load_modules(): except Exception as e: LOGGER.warning(f'Failed to import {module}: {e}') for kwargs in [ {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'}, {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'}, {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'}, {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'}, ]: _load_module(**kwargs) if target_modules is None: target_modules = MODULE_SCRIPTS if isinstance(target_modules, str): target_modules = [target_modules] for k in target_modules: _load_module(**MODULE_SCRIPTS[k]) def init_textdetector_registries(): init_module_registries('textdetector') def init_inpainter_registries(): init_module_registries('inpainter') def init_ocr_registries(): init_module_registries('ocr') def init_translator_registries(): init_module_registries('translator') scripts/run_module.py 0 → 100644 +66 −0 Original line number Diff line number Diff line import click import sys, os import os.path as osp sys.path.append(osp.dirname(osp.dirname(__file__))) from tqdm import tqdm from utils.config import load_config from utils.shared import PROGRAM_PATH from utils.textblock import visualize_textblocks from utils.proj_imgtrans import ProjImgTrans from utils.config import pcfg from utils.io_utils import imread, imwrite from modules import MODULETYPE_TO_REGISTRIES, init_translator_registries, init_inpainter_registries, init_ocr_registries, init_textdetector_registries os.chdir(PROGRAM_PATH) @click.group() def cli(): """text detector testing scripts. """ def init_module(module_type: str, module_name: str): assert module_type in MODULETYPE_TO_REGISTRIES module_cls = MODULETYPE_TO_REGISTRIES[module_type].get(module_name) module_cls_params = getattr(pcfg.module, module_type + '_params') module_params = module_cls_params.get(module_name, {}) return module_cls(**module_params) @cli.command('run_detector') @click.option('--proj_dir') @click.option('--detector', default=None) @click.option('--config', default='config/config.json') @click.option('--save_dir', default='tmp/test_ctd') def run_detector(proj_dir, detector, config, save_dir): init_textdetector_registries() load_config(config) if detector is None: detector = pcfg.module.textdetector detector = init_module('textdetector', detector) print('detector params:', detector.params) proj = ProjImgTrans(proj_dir) for page_name in tqdm(proj.pages): blk_list = proj.pages[page_name] proj.set_current_img(page_name) mask, blk_list = detector.detect(proj.img_array, blk_list) blk_list = blk_list[:1] print(blk_list[0].get_text()) vis = visualize_textblocks(proj.img_array, blk_list) imwrite(osp.join(save_dir, proj.current_img), vis, ext='.jpg') pass if __name__ == '__main__': cli() ui/module_manager.py +2 −2 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ from modules.translators import MissingTranslatorParams from modules.base import BaseModule, soft_empty_cache from modules import INPAINTERS, TRANSLATORS, TEXTDETECTORS, OCR, \ GET_VALID_TRANSLATORS, GET_VALID_TEXTDETECTORS, GET_VALID_INPAINTERS, GET_VALID_OCR, \ BaseTranslator, InpainterBase, TextDetectorBase, OCRBase BaseTranslator, InpainterBase, TextDetectorBase, OCRBase, merge_config_module_params import modules modules.translators.SYSTEM_LANG = QLocale.system().name() from utils.textblock import TextBlock, sort_regions Loading @@ -24,7 +24,7 @@ from utils.message import create_error_dialog, create_info_dialog from .custom_widget import ImgtransProgressMessageBox, ParamComboBox from .configpanel import ConfigPanel from utils.proj_imgtrans import ProjImgTrans from utils.config import pcfg, merge_config_module_params from utils.config import pcfg cfg_module = pcfg.module Loading Loading
launch.py +2 −2 Original line number Diff line number Diff line Loading @@ -215,9 +215,9 @@ def main(): # import msl.loadlib (required by translators/trans_eztrans) before init QApplication # yield QWindowsContext: OleInitialize() failed on py3.10, from modules.base import load_modules from modules.base import init_module_registries from modules.prepare_local_files import prepare_local_files_forall load_modules() init_module_registries() prepare_local_files_forall() if not args.headless: Loading
modules/__init__.py +10 −1 Original line number Diff line number Diff line Loading @@ -2,11 +2,20 @@ from .ocr import OCR, OCRBase from .textdetector import TEXTDETECTORS, TextDetectorBase from .translators import TRANSLATORS, BaseTranslator from .inpaint import INPAINTERS, InpainterBase from .base import DEFAULT_DEVICE, GPUINTENSIVE_SET from .base import DEFAULT_DEVICE, GPUINTENSIVE_SET, LOGGER, merge_config_module_params, \ init_module_registries, init_textdetector_registries, init_inpainter_registries, init_ocr_registries, init_translator_registries GET_VALID_TEXTDETECTORS = lambda : list(TEXTDETECTORS.module_dict.keys()) GET_VALID_TRANSLATORS = lambda : list(TRANSLATORS.module_dict.keys()) GET_VALID_INPAINTERS = lambda : list(INPAINTERS.module_dict.keys()) GET_VALID_OCR = lambda : list(OCR.module_dict.keys()) MODULETYPE_TO_REGISTRIES = { 'textdetector': TEXTDETECTORS, 'ocr': OCR, 'inpainter': INPAINTERS, 'translator': TRANSLATORS } # TODO: use manga-image-translator as backend... No newline at end of file
modules/base.py +107 −8 Original line number Diff line number Diff line Loading @@ -33,6 +33,79 @@ def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callabl hooks_registered[hk] = callback nhooks += 1 def patch_module_params(cfg_param, module_params, module_name: str = ''): # cfg_param = config_params[module_key] cfg_key_set = set(cfg_param.keys()) module_key_set = set(module_params.keys()) for ck in cfg_key_set: if ck not in module_key_set: LOGGER.warning(f'Found invalid {module_name} config: {ck}') cfg_param.pop(ck) for mk in module_key_set: if mk not in cfg_key_set: if not mk.startswith('__') and mk != 'description': LOGGER.info(f'Found new {module_name} config: {mk}') cfg_param[mk] = module_params[mk] else: mparam = module_params[mk] cparam = cfg_param[mk] if isinstance(mparam, dict): tgt_type = type(mparam['value']) if isinstance(cparam, dict): if 'value' in cparam: v = cparam['value'] elif isinstance(mparam['value'], dict): for k in mparam['value']: if k in cparam: mparam['value'][k] = cparam[k] v = mparam['value'] else: v = mparam['value'] else: v = cparam valid = True if tgt_type != type(v): try: v = tgt_type(v) except: valid = False LOGGER.warning(f'Invalid param value {v} for defined dtype: {tgt_type}, it will be set to default value: {mparam}') if valid: mparam['value'] = v cfg_param[mk] = mparam else: if type(cparam) != type(mparam): if not isinstance(mparam, dict) and isinstance(cparam, dict): cparam = cparam['value'] try: cfg_param[mk] = type(mparam)(cparam) except ValueError: LOGGER.warning(f'Invalid param value {cparam} for defined dtype: {type(mparam)}, it will be set to default value: {mparam}') cfg_param[mk] = mparam cfg_key_list = list(cfg_param.keys()) module_key_list = list(module_params.keys()) if cfg_key_list != module_key_list: new_params = {key: cfg_param[key] for key in module_key_list} cfg_param.clear() cfg_param.update(new_params) module_key_set = set(module_params.keys()) cfg_param['__param_patched'] = True return cfg_param def merge_config_module_params(config_params: Dict, module_keys: List, get_module: Callable) -> Dict: for module_key in module_keys: module_params = get_module(module_key).params if module_key not in config_params or config_params[module_key] is None: config_params[module_key] = module_params else: patch_module_params(config_params[module_key], module_params, module_key) return config_params class BaseModule: params: Dict = None Loading @@ -47,6 +120,8 @@ class BaseModule: _load_model_keys: set = None def __init__(self, **params) -> None: if self.params is not None and '__param_patched' not in params: params = patch_module_params(params, self.params, self) if params: if self.params is None: self.params = params Loading Loading @@ -228,7 +303,14 @@ TORCH_DTYPE_MAP = { 'bf16': torch.bfloat16, } def load_modules(): MODULE_SCRIPTS = { 'translator': {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'}, 'textdetector': {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'}, 'inpainter': {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'}, 'ocr': {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'}, } def init_module_registries(target_modules=None): def _load_module(module_dir: str, module_pattern: str): modules = os.listdir(module_dir) pattern = re.compile(module_pattern) Loading @@ -243,10 +325,27 @@ def load_modules(): except Exception as e: LOGGER.warning(f'Failed to import {module}: {e}') for kwargs in [ {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'}, {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'}, {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'}, {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'}, ]: _load_module(**kwargs) if target_modules is None: target_modules = MODULE_SCRIPTS if isinstance(target_modules, str): target_modules = [target_modules] for k in target_modules: _load_module(**MODULE_SCRIPTS[k]) def init_textdetector_registries(): init_module_registries('textdetector') def init_inpainter_registries(): init_module_registries('inpainter') def init_ocr_registries(): init_module_registries('ocr') def init_translator_registries(): init_module_registries('translator')
scripts/run_module.py 0 → 100644 +66 −0 Original line number Diff line number Diff line import click import sys, os import os.path as osp sys.path.append(osp.dirname(osp.dirname(__file__))) from tqdm import tqdm from utils.config import load_config from utils.shared import PROGRAM_PATH from utils.textblock import visualize_textblocks from utils.proj_imgtrans import ProjImgTrans from utils.config import pcfg from utils.io_utils import imread, imwrite from modules import MODULETYPE_TO_REGISTRIES, init_translator_registries, init_inpainter_registries, init_ocr_registries, init_textdetector_registries os.chdir(PROGRAM_PATH) @click.group() def cli(): """text detector testing scripts. """ def init_module(module_type: str, module_name: str): assert module_type in MODULETYPE_TO_REGISTRIES module_cls = MODULETYPE_TO_REGISTRIES[module_type].get(module_name) module_cls_params = getattr(pcfg.module, module_type + '_params') module_params = module_cls_params.get(module_name, {}) return module_cls(**module_params) @cli.command('run_detector') @click.option('--proj_dir') @click.option('--detector', default=None) @click.option('--config', default='config/config.json') @click.option('--save_dir', default='tmp/test_ctd') def run_detector(proj_dir, detector, config, save_dir): init_textdetector_registries() load_config(config) if detector is None: detector = pcfg.module.textdetector detector = init_module('textdetector', detector) print('detector params:', detector.params) proj = ProjImgTrans(proj_dir) for page_name in tqdm(proj.pages): blk_list = proj.pages[page_name] proj.set_current_img(page_name) mask, blk_list = detector.detect(proj.img_array, blk_list) blk_list = blk_list[:1] print(blk_list[0].get_text()) vis = visualize_textblocks(proj.img_array, blk_list) imwrite(osp.join(save_dir, proj.current_img), vis, ext='.jpg') pass if __name__ == '__main__': cli()
ui/module_manager.py +2 −2 Original line number Diff line number Diff line Loading @@ -15,7 +15,7 @@ from modules.translators import MissingTranslatorParams from modules.base import BaseModule, soft_empty_cache from modules import INPAINTERS, TRANSLATORS, TEXTDETECTORS, OCR, \ GET_VALID_TRANSLATORS, GET_VALID_TEXTDETECTORS, GET_VALID_INPAINTERS, GET_VALID_OCR, \ BaseTranslator, InpainterBase, TextDetectorBase, OCRBase BaseTranslator, InpainterBase, TextDetectorBase, OCRBase, merge_config_module_params import modules modules.translators.SYSTEM_LANG = QLocale.system().name() from utils.textblock import TextBlock, sort_regions Loading @@ -24,7 +24,7 @@ from utils.message import create_error_dialog, create_info_dialog from .custom_widget import ImgtransProgressMessageBox, ParamComboBox from .configpanel import ConfigPanel from utils.proj_imgtrans import ProjImgTrans from utils.config import pcfg, merge_config_module_params from utils.config import pcfg cfg_module = pcfg.module Loading