Unverified Commit fd01e3a2 authored by Sergey Pinus's avatar Sergey Pinus Committed by GitHub
Browse files

Merge branch 'dmMaze:dev' into dev

parents dcd3414f 6eec0cdc
Loading
Loading
Loading
Loading
+32 −1
Original line number Diff line number Diff line
@@ -137,6 +137,7 @@ Sugoi 翻译器作者: [mingshiba](https://www.patreon.com/mingshiba)
 * 暂时仅支持日文(方块字都差不多)和英文检测,训练代码和说明见https://github.com/dmMaze/comic-text-detector
 * 支持使用 [星河云(团子漫画OCR)](https://cloud.stariver.org.cn/)的文本检测,需要填写用户名和密码,每次启动时会自动登录。
   * 详细说明见 [团子OCR说明](doc/团子OCR说明.md)
 * `YSGDetector` 是由 [lhj5426](https://github.com/lhj5426) 训练的模型,能更好地过滤日漫/CG里的拟声词。需要手动从 [YSGYoloDetector](https://huggingface.co/dreMaz/YSGYoloDetector) 下载模型放到 data/models 目录下。


### OCR
@@ -170,7 +171,37 @@ Sugoi 翻译器作者: [mingshiba](https://www.patreon.com/mingshiba)
如需添加新的翻译器请参考[加别的翻译器](doc/加别的翻译器.md),本程序添加新翻译器只需要继承基类实现两个接口即可不需要理会代码其他部分,欢迎大佬提 pr

## 杂
* 电脑带N卡或 Apple silicon 默认启用 GPU 加速
* 电脑带 Nvidia 显卡或 Apple silicon 默认启用 GPU 加速
* 感谢 [bropines](https://github.com/bropines) 提供俄语翻译
* 第三方输入法可能会造成右侧编辑框显示 bug,见[#76](https://github.com/dmMaze/BallonsTranslator/issues/76),暂时不打算修
* 选中文本迷你菜单支持*聚合词典专业划词翻译*[沙拉查词](https://saladict.crimx.com): [安装说明](doc/saladict_chs.md)
* 启用 AMD(ROCm6)显卡加速步骤
   * 更新显卡驱动至最新版(建议 24.12.1 及以上)
   * 下载并安装 [AMD HIP SDK 6.2](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
   * 下载 [ZLUDA](https://github.com/lshqqytiger/ZLUDA/releases)(ROCm6版本)并解压到 zluda 文件夹内
   * 复制 zluda 文件夹到系统盘下:比如c盘(C:\zluda)
   * 配置系统环境变量
  
      这里以 windows 10 系统为例:设置 - 系统属性 - 高级系统设置 - 环境变量 - 系统变量 - 找到 path 变量
  
      点击编辑 在最后添加 `C:\zluda``%HIP_PATH_62%bin` 两项
  
   * 替换 CUDA 库的动态链接文件
  
`C:\zluda` 文件夹内的 `cublas64_11.dll` `cusparse64_11.dll``nvrtc64_112_0.dll` 复制出一份到桌面

      按如下规则重命名复制出来的文件

      `原文件名` → `新文件名`

      `cublas.dll` → `cublas64_11.dll`

      `cusparse.dll` → `cusparse64_11.dll`

      `nvrtc.dll` → `nvrtc64_112_0.dll`
    
      将已经重命名的文件替换掉 `BallonsTranslator\ballontrans_pylibs_win\Lib\site-packages\torch\lib\` 目录中的同名文件

    * 启动程序并设置 OCR 和文本检测 为 Cuda **(图像修复请继续使用 CPU)**
    * 运行 OCR 并等待 ZLUDA 编译 PTX 文件 **(首次编译大概需要 5-10 分钟,取决于 CPU 性能)**
    * **下次运行无需编译**
+5 −1
Original line number Diff line number Diff line
@@ -139,6 +139,10 @@ This project is heavily dependent upon [manga-image-translator](https://github.c
 * Support using text detection from [Starriver Cloud (Tuanzi Manga OCR)](https://cloud.stariver.org.cn/). Username and password need to be filled in, and automatic login will be performed each time the program is launched.

   * For detailed instructions, see **Tuanzi OCR Instructions**: ([Chinese](doc/团子OCR说明.md) & [Brazilian Portuguese](doc/Manual_TuanziOCR_pt-BR.md) only)
 
 * `YSGDetector` models are trained by [lhj5426](https://github.com/lhj5426), these models would filter out onomatopoeia in CGs/Manga, download checkpoints from [YSGYoloDetector](https://huggingface.co/dreMaz/YSGYoloDetector) and put into `data/models`. 


## OCR
 * All mit* models are from manga-image-translator, support English, Japanese and Korean recognition and text color extraction.
 * [manga_ocr](https://github.com/kha-white/manga-ocr) is from [kha-white](https://github.com/kha-white), text recognition for Japanese, with the main focus being Japanese manga.
+9 −32
Original line number Diff line number Diff line
@@ -4,11 +4,9 @@ import argparse
import os.path as osp
import os
import importlib
import re
import subprocess
import pkg_resources
from platform import platform
import logging

BRANCH = 'dev'
VERSION = '1.4.0'
@@ -109,30 +107,6 @@ def commit_hash():
    return stored_commit_hash


def load_modules():
    LOGGER = logging.getLogger('BallonTranslator')
    def _load_module(module_dir: str, module_pattern: str):
        modules = os.listdir(module_dir)
        pattern = re.compile(module_pattern)
        module_path = module_dir.replace('/', '.')
        if not module_path.endswith('.'):
            module_path += '.'
        for module_name in modules:
            if pattern.match(module_name) is not None:
                try:
                    module = module_path + module_name.replace('.py', '')
                    importlib.import_module(module)
                except Exception as e:
                    LOGGER.warning(f'Failed to import {module}: {e}')

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
    ]:
        _load_module(**kwargs)

BT = None
APP = None

@@ -143,7 +117,6 @@ def restart():
        BT.close()
    os.execv(sys.executable, ['python'] + sys.argv)


def main():

    if args.debug:
@@ -153,10 +126,10 @@ def main():

    commit = commit_hash()

    print('py version: ', sys.version)
    print('py executable: ', sys.executable)
    print(f'version: {VERSION}')
    print(f'branch: {BRANCH}')
    print('Python version: ', sys.version)
    print('Python executable: ', sys.executable)
    print(f'Version: {VERSION}')
    print(f'Branch: {BRANCH}')
    print(f"Commit hash: {commit}")

    APP_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -164,6 +137,9 @@ def main():

    prepare_environment()

    from utils.zluda_config import enable_zluda_config
    enable_zluda_config()

    if args.update:
        if getattr(sys, 'frozen', False):
            print('Running as app, skipping update.')
@@ -230,8 +206,9 @@ def main():

    setup_logging(shared.LOGGING_PATH)

    load_modules()
    from modules.base import load_modules
    from modules.prepare_local_files import prepare_local_files_forall
    load_modules()
    prepare_local_files_forall()

    app_args = sys.argv
+30 −1
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ import time
from typing import Dict, List, Callable, Union
from copy import deepcopy
from collections import OrderedDict
import re
import importlib

from utils.logger import logger as LOGGER
from utils import shared
@@ -124,6 +126,8 @@ class BaseModule:
                if hasattr(self, k):
                    model = getattr(self, k)
                    if model is not None:
                        if hasattr(model, 'unload_model'):
                            model.unload_model(empty_cache=False)
                        del model
                        setattr(self, k, None)
                        model_deleted = True
@@ -156,6 +160,9 @@ class BaseModule:
    def debug_mode(self):
        return shared.DEBUG
    
    def flush(self, param_key: str):
        return None

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch

@@ -220,3 +227,25 @@ TORCH_DTYPE_MAP = {
    'bf16': torch.bfloat16,
}
    
def load_modules():
    def _load_module(module_dir: str, module_pattern: str):
        modules = os.listdir(module_dir)
        pattern = re.compile(module_pattern)
        module_path = module_dir.replace('/', '.')
        if not module_path.endswith('.'):
            module_path += '.'
        for module_name in modules:
            if pattern.match(module_name) is not None:
                try:
                    module = module_path + module_name.replace('.py', '')
                    importlib.import_module(module)
                except Exception as e:
                    LOGGER.warning(f'Failed to import {module}: {e}')

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
    ]:
        _load_module(**kwargs)
+8 −3
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from typing import Union, List, Tuple
from collections import OrderedDict

from utils.textblock import TextBlock
from utils.proj_imgtrans import ProjImgTrans

from utils.registry import Registry
TEXTDETECTORS = Registry('textdetectors')
@@ -26,16 +27,20 @@ class TextDetectorBase(BaseModule):
                self.name = key
                break

    def _detect(self, *args, **kwargs) -> Tuple[np.ndarray, List[TextBlock]]:
    def _detect(self, img: np.ndarray, proj: ProjImgTrans) -> Tuple[np.ndarray, List[TextBlock]]:
        '''
        The proj context can be accessed via ```proj```
        '''
        raise NotImplementedError

    def setup_detector(self):
        raise NotImplementedError

    def detect(self, img: np.ndarray) -> Tuple[np.ndarray, List[TextBlock]]:
    def detect(self, img: np.ndarray, proj: ProjImgTrans = None) -> Tuple[np.ndarray, List[TextBlock]]:
        # TODO: allow processing proj entirely in _detect and yield progress
        if not self.all_model_loaded():
            self.load_model()
        mask, blk_list = self._detect(img)
        mask, blk_list = self._detect(img, proj)
        for blk in blk_list:
            blk.det_model = self.name
        return mask, blk_list
Loading