Unverified Commit cb8ed0b9 authored by dmMaze's avatar dmMaze Committed by GitHub
Browse files

Merge pull request #794 from yihuishou/dev

Add ZLUDA 3.9.0 config to Support AMD
parents 1a6e28eb 21d81245
Loading
Loading
Loading
Loading
+31 −1
Original line number Diff line number Diff line
@@ -170,7 +170,37 @@ Sugoi 翻译器作者: [mingshiba](https://www.patreon.com/mingshiba)
如需添加新的翻译器请参考[加别的翻译器](doc/加别的翻译器.md),本程序添加新翻译器只需要继承基类实现两个接口即可不需要理会代码其他部分,欢迎大佬提 pr

## 杂
* 电脑带N卡或 Apple silicon 默认启用 GPU 加速
* 电脑带 Nvidia 显卡或 Apple silicon 默认启用 GPU 加速
* 感谢 [bropines](https://github.com/bropines) 提供俄语翻译
* 第三方输入法可能会造成右侧编辑框显示 bug,见[#76](https://github.com/dmMaze/BallonsTranslator/issues/76),暂时不打算修
* 选中文本迷你菜单支持*聚合词典专业划词翻译*[沙拉查词](https://saladict.crimx.com): [安装说明](doc/saladict_chs.md)
* 启用 AMD(ROCm6)显卡加速步骤
   * 更新显卡驱动至最新版(建议 24.12.1 及以上)
   * 下载并安装 [AMD HIP SDK 6.2](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
   * 下载 [ZLUDA](https://github.com/lshqqytiger/ZLUDA/releases)(ROCm6版本)并解压到 zluda 文件夹内
   * 复制 zluda 文件夹到系统盘下:比如c盘(C:\zluda)
   * 配置系统环境变量
  
      这里以 windows 10 系统为例:设置 - 系统属性 - 高级系统设置 - 环境变量 - 系统变量 - 找到 path 变量
  
      点击编辑 在最后添加 `C:\zluda``%HIP_PATH_62%bin` 两项
  
   * 替换 CUDA 库的动态链接文件
  
`C:\zluda` 文件夹内的 `cublas64_11.dll` `cusparse64_11.dll``nvrtc64_112_0.dll` 复制出一份到桌面

      按如下规则重命名复制出来的文件

      `原文件名` → `新文件名`

      `cublas.dll` → `cublas64_11.dll`

      `cusparse.dll` → `cusparse64_11.dll`

      `nvrtc.dll` → `nvrtc64_112_0.dll`
    
      将已经重命名的文件替换掉 `BallonsTranslator\ballontrans_pylibs_win\Lib\site-packages\torch\lib\` 目录中的同名文件

    * 启动程序并设置 OCR 和文本检测 为 Cuda **(图像修复请继续使用 CPU)**
    * 运行 OCR 并等待 ZLUDA 编译 PTX 文件 **(首次编译大概需要 5-10 分钟,取决于 CPU 性能)**
    * **下次运行无需编译**
+9 −32
Original line number Diff line number Diff line
@@ -4,11 +4,9 @@ import argparse
import os.path as osp
import os
import importlib
import re
import subprocess
import pkg_resources
from platform import platform
import logging

BRANCH = 'dev'
VERSION = '1.4.0'
@@ -109,30 +107,6 @@ def commit_hash():
    return stored_commit_hash


def load_modules():
    LOGGER = logging.getLogger('BallonTranslator')
    def _load_module(module_dir: str, module_pattern: str):
        modules = os.listdir(module_dir)
        pattern = re.compile(module_pattern)
        module_path = module_dir.replace('/', '.')
        if not module_path.endswith('.'):
            module_path += '.'
        for module_name in modules:
            if pattern.match(module_name) is not None:
                try:
                    module = module_path + module_name.replace('.py', '')
                    importlib.import_module(module)
                except Exception as e:
                    LOGGER.warning(f'Failed to import {module}: {e}')

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
    ]:
        _load_module(**kwargs)

BT = None
APP = None

@@ -143,7 +117,6 @@ def restart():
        BT.close()
    os.execv(sys.executable, ['python'] + sys.argv)


def main():

    if args.debug:
@@ -153,10 +126,10 @@ def main():

    commit = commit_hash()

    print('py version: ', sys.version)
    print('py executable: ', sys.executable)
    print(f'version: {VERSION}')
    print(f'branch: {BRANCH}')
    print('Python version: ', sys.version)
    print('Python executable: ', sys.executable)
    print(f'Version: {VERSION}')
    print(f'Branch: {BRANCH}')
    print(f"Commit hash: {commit}")

    APP_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -164,6 +137,9 @@ def main():

    prepare_environment()

    from utils.zluda_config import enable_zluda_config
    enable_zluda_config()

    if args.update:
        if getattr(sys, 'frozen', False):
            print('Running as app, skipping update.')
@@ -230,8 +206,9 @@ def main():

    setup_logging(shared.LOGGING_PATH)

    load_modules()
    from modules.base import load_modules
    from modules.prepare_local_files import prepare_local_files_forall
    load_modules()
    prepare_local_files_forall()

    app_args = sys.argv
+25 −1
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ import time
from typing import Dict, List, Callable, Union
from copy import deepcopy
from collections import OrderedDict
import re
import importlib

from utils.logger import logger as LOGGER
from utils import shared
@@ -220,3 +222,25 @@ TORCH_DTYPE_MAP = {
    'bf16': torch.bfloat16,
}
    
def load_modules():
    def _load_module(module_dir: str, module_pattern: str):
        modules = os.listdir(module_dir)
        pattern = re.compile(module_pattern)
        module_path = module_dir.replace('/', '.')
        if not module_path.endswith('.'):
            module_path += '.'
        for module_name in modules:
            if pattern.match(module_name) is not None:
                try:
                    module = module_path + module_name.replace('.py', '')
                    importlib.import_module(module)
                except Exception as e:
                    LOGGER.warning(f'Failed to import {module}: {e}')

    for kwargs in [
        {'module_dir': 'modules/translators', 'module_pattern': r'trans_(.*?).py'},
        {'module_dir': 'modules/textdetector', 'module_pattern': r'detector_(.*?).py'},
        {'module_dir': 'modules/inpaint', 'module_pattern': r'inpaint_(.*?).py'},
        {'module_dir': 'modules/ocr', 'module_pattern': r'ocr_(.*?).py'},
    ]:
        _load_module(**kwargs)

utils/zluda_config.py

0 → 100644
+21 −0
Original line number Diff line number Diff line
import torch

# 检测是否包含 ZLUDA 标记
def zluda_available(device_name):
    return "[ZLUDA]" in device_name

# 关闭 ZLUDA Cudnn 支持 防止错误
def enable_zluda_config():
    if hasattr(torch, 'cuda') and torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
        print('Device name: ', device_name)
        print('Cuda is available: ', torch.cuda.is_available())
        print('Cuda version: ', torch.version.cuda)
        print('ZLUDA is available: ', zluda_available(device_name))

        if zluda_available(device_name):
            torch.backends.cudnn.enabled = False
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_cudnn_sdp(False)