Loading modules/base.py +26 −3 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ from utils.logger import logger as LOGGER from utils import shared GPUINTENSIVE_SET = {'cuda', 'xpu', 'mps'} GPUINTENSIVE_SET = {'cuda', 'mps', 'xpu', 'privateuseone'} def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callable, Dict]): if callbacks is None: Loading Loading @@ -159,10 +159,18 @@ class BaseModule: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import torch import torch_directml DEFAULT_DEVICE = 'cpu' if hasattr(torch, 'cuda') and torch.cuda.is_available(): DEFAULT_DEVICE = 'cuda' elif hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: from modules.dml import directml_init, directml_do_hijack directml_init() directml_do_hijack() for d in range(torch.cuda.device_count()): print(f"device {d}: {torch.cuda.get_device_name(d)}") DEFAULT_DEVICE = 'cpu' elif hasattr(torch, 'xpu') and torch.xpu.is_available(): DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu' elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): Loading @@ -183,7 +191,7 @@ def is_intel(): def soft_empty_cache(): gc.collect() if DEFAULT_DEVICE == 'cuda': if DEFAULT_DEVICE in ('cuda', 'privateuseone'): torch.cuda.empty_cache() torch.cuda.ipc_collect() elif DEFAULT_DEVICE == 'xpu': Loading @@ -192,7 +200,22 @@ def soft_empty_cache(): elif DEFAULT_DEVICE == 'mps': torch.mps.empty_cache() DEVICE_SELECTOR = lambda : deepcopy( def DEVICE_SELECTOR(): return deepcopy( { 'type': 'selector', 'options': [ 'cpu', 'cuda', 'mps', 'privateuseone' ], 'value': DEFAULT_DEVICE } ) def DEVICE_SELECTOR_NO_DML(): return deepcopy( { 'type': 'selector', 'options': [ Loading modules/inpaint/base.py +5 −5 Original line number Diff line number Diff line Loading @@ -8,7 +8,7 @@ from utils.registry import Registry from utils.textblock_mask import extract_ballon_mask from utils.imgproc_utils import enlarge_window from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED from ..textdetector import TextBlock INPAINTERS = Registry('inpainters') Loading Loading @@ -292,7 +292,7 @@ class LamaInpainterMPE(InpainterBase): ], 'value': 2048 }, 'device': DEVICE_SELECTOR() 'device': DEVICE_SELECTOR_NO_DML() } download_file_list = [{ Loading Loading @@ -415,7 +415,7 @@ class LamaLarge(LamaInpainterMPE): ], 'value': 1536, }, 'device': DEVICE_SELECTOR(), 'device': DEVICE_SELECTOR_NO_DML(), 'precision': { 'type': 'selector', 'options': [ Loading modules/ocr/__init__.py +1 −1 Original line number Diff line number Diff line from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEFAULT_DEVICE, TextBlock, OCR No newline at end of file from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, DEFAULT_DEVICE, TextBlock, OCR modules/ocr/base.py +1 −1 Original line number Diff line number Diff line Loading @@ -8,7 +8,7 @@ from utils.registry import Registry OCR = Registry('OCR') register_OCR = OCR.register_module from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, LOGGER class OCRBase(BaseModule): Loading modules/ocr/ocr_mit.py +6 −2 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ from typing import List import numpy as np from copy import deepcopy from .base import DEVICE_SELECTOR, OCRBase, register_OCR, TextBlock from .base import DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, OCRBase, register_OCR, TextBlock from utils.textblock import collect_textblock_regions mit_params = { Loading @@ -15,6 +15,10 @@ mit_params = { 'description': 'OCRMIT32px' } mit_params_no_dml = deepcopy(mit_params) mit_params_no_dml['device'] = DEVICE_SELECTOR_NO_DML() class MITModels(OCRBase): _line_only = True Loading Loading @@ -64,7 +68,7 @@ from .mit48px_ctc import OCR48pxCTC @register_OCR('mit48px_ctc') class OCRMIT48pxCTC(MITModels): params = deepcopy(mit_params) params = deepcopy(mit_params_no_dml) download_file_list = [{ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip', 'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'], Loading Loading
modules/base.py +26 −3 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ from utils.logger import logger as LOGGER from utils import shared GPUINTENSIVE_SET = {'cuda', 'xpu', 'mps'} GPUINTENSIVE_SET = {'cuda', 'mps', 'xpu', 'privateuseone'} def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callable, Dict]): if callbacks is None: Loading Loading @@ -159,10 +159,18 @@ class BaseModule: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import torch import torch_directml DEFAULT_DEVICE = 'cpu' if hasattr(torch, 'cuda') and torch.cuda.is_available(): DEFAULT_DEVICE = 'cuda' elif hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: from modules.dml import directml_init, directml_do_hijack directml_init() directml_do_hijack() for d in range(torch.cuda.device_count()): print(f"device {d}: {torch.cuda.get_device_name(d)}") DEFAULT_DEVICE = 'cpu' elif hasattr(torch, 'xpu') and torch.xpu.is_available(): DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu' elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): Loading @@ -183,7 +191,7 @@ def is_intel(): def soft_empty_cache(): gc.collect() if DEFAULT_DEVICE == 'cuda': if DEFAULT_DEVICE in ('cuda', 'privateuseone'): torch.cuda.empty_cache() torch.cuda.ipc_collect() elif DEFAULT_DEVICE == 'xpu': Loading @@ -192,7 +200,22 @@ def soft_empty_cache(): elif DEFAULT_DEVICE == 'mps': torch.mps.empty_cache() DEVICE_SELECTOR = lambda : deepcopy( def DEVICE_SELECTOR(): return deepcopy( { 'type': 'selector', 'options': [ 'cpu', 'cuda', 'mps', 'privateuseone' ], 'value': DEFAULT_DEVICE } ) def DEVICE_SELECTOR_NO_DML(): return deepcopy( { 'type': 'selector', 'options': [ Loading
modules/inpaint/base.py +5 −5 Original line number Diff line number Diff line Loading @@ -8,7 +8,7 @@ from utils.registry import Registry from utils.textblock_mask import extract_ballon_mask from utils.imgproc_utils import enlarge_window from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED from ..base import BaseModule, DEFAULT_DEVICE, soft_empty_cache, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, GPUINTENSIVE_SET, TORCH_DTYPE_MAP, BF16_SUPPORTED from ..textdetector import TextBlock INPAINTERS = Registry('inpainters') Loading Loading @@ -292,7 +292,7 @@ class LamaInpainterMPE(InpainterBase): ], 'value': 2048 }, 'device': DEVICE_SELECTOR() 'device': DEVICE_SELECTOR_NO_DML() } download_file_list = [{ Loading Loading @@ -415,7 +415,7 @@ class LamaLarge(LamaInpainterMPE): ], 'value': 1536, }, 'device': DEVICE_SELECTOR(), 'device': DEVICE_SELECTOR_NO_DML(), 'precision': { 'type': 'selector', 'options': [ Loading
modules/ocr/__init__.py +1 −1 Original line number Diff line number Diff line from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEFAULT_DEVICE, TextBlock, OCR No newline at end of file from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, DEFAULT_DEVICE, TextBlock, OCR
modules/ocr/base.py +1 −1 Original line number Diff line number Diff line Loading @@ -8,7 +8,7 @@ from utils.registry import Registry OCR = Registry('OCR') register_OCR = OCR.register_module from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, LOGGER from ..base import BaseModule, DEFAULT_DEVICE, DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, LOGGER class OCRBase(BaseModule): Loading
modules/ocr/ocr_mit.py +6 −2 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ from typing import List import numpy as np from copy import deepcopy from .base import DEVICE_SELECTOR, OCRBase, register_OCR, TextBlock from .base import DEVICE_SELECTOR, DEVICE_SELECTOR_NO_DML, OCRBase, register_OCR, TextBlock from utils.textblock import collect_textblock_regions mit_params = { Loading @@ -15,6 +15,10 @@ mit_params = { 'description': 'OCRMIT32px' } mit_params_no_dml = deepcopy(mit_params) mit_params_no_dml['device'] = DEVICE_SELECTOR_NO_DML() class MITModels(OCRBase): _line_only = True Loading Loading @@ -64,7 +68,7 @@ from .mit48px_ctc import OCR48pxCTC @register_OCR('mit48px_ctc') class OCRMIT48pxCTC(MITModels): params = deepcopy(mit_params) params = deepcopy(mit_params_no_dml) download_file_list = [{ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip', 'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'], Loading