Commit 5ea39d4c authored by minicom's avatar minicom
Browse files

Merge commit '93aae371' into feature-directml

# Conflicts:
#	modules/base.py
#	translate/ko_KR.qm
#	translate/ko_KR.ts
parents c6dd94d0 93aae371
Loading
Loading
Loading
Loading
+26 −3
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from utils.logger import logger as LOGGER
from utils import shared


GPUINTENSIVE_SET = {'cuda', 'xpu', 'mps'}
GPUINTENSIVE_SET = {'cuda', 'mps', 'xpu', 'privateuseone'}

def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callable, Dict]):
    if callbacks is None:
@@ -159,10 +159,18 @@ class BaseModule:

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import torch_directml

DEFAULT_DEVICE = 'cpu'
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda'
elif hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0:
    from modules.dml import directml_init, directml_do_hijack
    directml_init()
    directml_do_hijack()
    for d in range(torch.cuda.device_count()):
        print(f"device {d}: {torch.cuda.get_device_name(d)}")
    DEFAULT_DEVICE = 'cpu'
elif hasattr(torch, 'xpu')  and torch.xpu.is_available():
    DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu'
elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
@@ -183,7 +191,7 @@ def is_intel():

def soft_empty_cache():
    gc.collect()
    if DEFAULT_DEVICE == 'cuda':
    if DEFAULT_DEVICE in ('cuda', 'privateuseone'):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif DEFAULT_DEVICE == 'xpu':
@@ -192,7 +200,22 @@ def soft_empty_cache():
    elif DEFAULT_DEVICE == 'mps':
        torch.mps.empty_cache()

DEVICE_SELECTOR = lambda : deepcopy(

def DEVICE_SELECTOR(): return deepcopy(
    {
        'type': 'selector',
        'options': [
            'cpu',
            'cuda',
            'mps',
            'privateuseone'
        ],
        'value': DEFAULT_DEVICE
    }
)


def DEVICE_SELECTOR_NO_DML(): return deepcopy(
    {
        'type': 'selector',
        'options': [
+5 −5
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from utils.registry import Registry
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED
from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED
from ..textdetector import TextBlock

INPAINTERS = Registry('inpainters')
@@ -292,7 +292,7 @@ class LamaInpainterMPE(InpainterBase):
            ], 
            'value': 2048
        },
        'device': DEVICE_SELECTOR()
        'device': DEVICE_SELECTOR_NO_DML()
    }

    download_file_list = [{
@@ -415,7 +415,7 @@ class LamaLarge(LamaInpainterMPE):
            ], 
            'value': 1536,
        },
        'device': DEVICE_SELECTOR(),
        'device': DEVICE_SELECTOR_NO_DML(),
        'precision': {
            'type': 'selector',
            'options': [
+1 −1
Original line number Diff line number Diff line
from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEFAULT_DEVICE, TextBlock, OCR
 No newline at end of file
from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, DEFAULT_DEVICE, TextBlock, OCR
+1 −1
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from utils.registry import Registry
OCR = Registry('OCR')
register_OCR = OCR.register_module

from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER
from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, LOGGER

class OCRBase(BaseModule):

+6 −2
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ from typing import List
import numpy as np
from copy import deepcopy

from .base import DEVICE_SELECTOR, OCRBase, register_OCR, TextBlock
from .base import DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, OCRBase, register_OCR, TextBlock
from utils.textblock import collect_textblock_regions

mit_params = {
@@ -15,6 +15,10 @@ mit_params = {
    'description': 'OCRMIT32px'
}

mit_params_no_dml = deepcopy(mit_params)
mit_params_no_dml['device'] = DEVICE_SELECTOR_NO_DML()


class MITModels(OCRBase):

    _line_only = True
@@ -64,7 +68,7 @@ from .mit48px_ctc import OCR48pxCTC
@register_OCR('mit48px_ctc')
class OCRMIT48pxCTC(MITModels):

    params = deepcopy(mit_params)
    params = deepcopy(mit_params_no_dml)
    download_file_list = [{
        'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip',
        'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'],
Loading