Commit 1648a51a authored by dmMaze's avatar dmMaze
Browse files

try to fix import issue #794

parent 0295fd04
Loading
Loading
Loading
Loading
+2 −27
Original line number Diff line number Diff line
@@ -4,11 +4,9 @@ import argparse
import os.path as osp
import os
import importlib
import re
import subprocess
import pkg_resources
from platform import platform
import logging

BRANCH = 'dev'
VERSION = '1.4.0'
@@ -109,30 +107,6 @@ def commit_hash():
    return stored_commit_hash


def load_modules():
    LOGGER = logging.getLogger('BallonTranslator')
    def _load_module(module_dir: str, module_pattern: str):
        modules = os.listdir(module_dir)
        pattern = re.compile(module_pattern)
        module_path = module_dir.replace('/', '.')
        if not module_path.endswith('.'):
            module_path += '.'
        for module_name in modules:
            if pattern.match(module_name) is not None:
                try:
                    module = module_path + module_name.replace('.py', '')
                    importlib.import_module(module)
                except Exception as e:
                    LOGGER.warning(f'Failed to import {module}: {e}')

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
    ]:
        _load_module(**kwargs)

BT = None
APP = None

@@ -229,8 +203,9 @@ def main():

    setup_logging(shared.LOGGING_PATH)

    load_modules()
    from modules.base import load_modules
    from modules.prepare_local_files import prepare_local_files_forall
    load_modules()
    prepare_local_files_forall()

    app_args = sys.argv
+25 −1
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ import time
from typing import Dict, List, Callable, Union
from copy import deepcopy
from collections import OrderedDict
import re
import importlib

from utils.logger import logger as LOGGER
from utils import shared
@@ -234,3 +236,25 @@ TORCH_DTYPE_MAP = {
    'bf16': torch.bfloat16,
}
    
def load_modules():
    def _load_module(module_dir: str, module_pattern: str):
        modules = os.listdir(module_dir)
        pattern = re.compile(module_pattern)
        module_path = module_dir.replace('/', '.')
        if not module_path.endswith('.'):
            module_path += '.'
        for module_name in modules:
            if pattern.match(module_name) is not None:
                try:
                    module = module_path + module_name.replace('.py', '')
                    importlib.import_module(module)
                except Exception as e:
                    LOGGER.warning(f'Failed to import {module}: {e}')

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
    ]:
        _load_module(**kwargs)