Loading modules/base.py +15 −24 Original line number Diff line number Diff line Loading @@ -156,25 +156,26 @@ class BaseModule: def debug_mode(self): return shared.DEBUG os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import torch import torch_directml DEFAULT_DEVICE = 'cpu' AVAILABLE_DEVICES = ['cpu'] if hasattr(torch, 'cuda') and torch.cuda.is_available(): DEFAULT_DEVICE = 'cuda' elif hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: from modules.dml import directml_init, directml_do_hijack directml_init() directml_do_hijack() for d in range(torch.cuda.device_count()): print(f"device {d}: {torch.cuda.get_device_name(d)}") DEFAULT_DEVICE = 'cpu' elif hasattr(torch, 'xpu') and torch.xpu.is_available(): AVAILABLE_DEVICES.append(DEFAULT_DEVICE) if hasattr(torch, 'xpu') and torch.xpu.is_available(): DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu' elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): AVAILABLE_DEVICES.append(DEFAULT_DEVICE) if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): DEFAULT_DEVICE = 'mps' AVAILABLE_DEVICES.append(DEFAULT_DEVICE) if hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: torch.dml = torch_directml [f"privateuseone:{d}" for d in range(torch.dml.device_count())] DEFAULT_DEVICE = f'privateuseone:{torch.dml.default_device()}' AVAILABLE_DEVICES += [f"privateuseone:{d}" for d in range(torch.dml.device_count())] BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported() def is_nvidia(): Loading @@ -191,7 +192,7 @@ def is_intel(): def soft_empty_cache(): gc.collect() if DEFAULT_DEVICE in ('cuda', 'privateuseone'): if DEFAULT_DEVICE == 'cuda': torch.cuda.empty_cache() torch.cuda.ipc_collect() elif DEFAULT_DEVICE == 'xpu': Loading @@ -204,12 +205,7 @@ def soft_empty_cache(): def DEVICE_SELECTOR(): return deepcopy( { 'type': 'selector', 'options': [ 'cpu', 'cuda', 'mps', 'privateuseone' ], 'options': AVAILABLE_DEVICES, 'value': DEFAULT_DEVICE } ) Loading @@ -218,13 +214,8 @@ def DEVICE_SELECTOR(): return deepcopy( def DEVICE_SELECTOR_NO_DML(): return deepcopy( { 'type': 'selector', 'options': [ 'cpu', 'cuda', 'xpu', 'mps' ], 'value': DEFAULT_DEVICE 'options': [opt for opt in AVAILABLE_DEVICES if not 'privateuseone' in opt], 'value': DEFAULT_DEVICE if DEFAULT_DEVICE != "privateuseone" else "cpu" } ) Loading modules/ocr/ocr_mit.py +2 −6 Original line number Diff line number Diff line Loading @@ -11,14 +11,10 @@ mit_params = { 'options': [8, 16, 24, 32], 'value': 16 }, 'device': DEVICE_SELECTOR(), 'device': DEVICE_SELECTOR_NO_DML(), 'description': 'OCRMIT32px' } mit_params_no_dml = deepcopy(mit_params) mit_params_no_dml['device'] = DEVICE_SELECTOR_NO_DML() class MITModels(OCRBase): _line_only = True Loading Loading @@ -68,7 +64,7 @@ from .mit48px_ctc import OCR48pxCTC @register_OCR('mit48px_ctc') class OCRMIT48pxCTC(MITModels): params = deepcopy(mit_params_no_dml) params = deepcopy(mit_params) download_file_list = [{ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip', 'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'], Loading Loading
modules/base.py +15 −24 Original line number Diff line number Diff line Loading @@ -156,25 +156,26 @@ class BaseModule: def debug_mode(self): return shared.DEBUG os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import torch import torch_directml DEFAULT_DEVICE = 'cpu' AVAILABLE_DEVICES = ['cpu'] if hasattr(torch, 'cuda') and torch.cuda.is_available(): DEFAULT_DEVICE = 'cuda' elif hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: from modules.dml import directml_init, directml_do_hijack directml_init() directml_do_hijack() for d in range(torch.cuda.device_count()): print(f"device {d}: {torch.cuda.get_device_name(d)}") DEFAULT_DEVICE = 'cpu' elif hasattr(torch, 'xpu') and torch.xpu.is_available(): AVAILABLE_DEVICES.append(DEFAULT_DEVICE) if hasattr(torch, 'xpu') and torch.xpu.is_available(): DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu' elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): AVAILABLE_DEVICES.append(DEFAULT_DEVICE) if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): DEFAULT_DEVICE = 'mps' AVAILABLE_DEVICES.append(DEFAULT_DEVICE) if hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: torch.dml = torch_directml [f"privateuseone:{d}" for d in range(torch.dml.device_count())] DEFAULT_DEVICE = f'privateuseone:{torch.dml.default_device()}' AVAILABLE_DEVICES += [f"privateuseone:{d}" for d in range(torch.dml.device_count())] BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported() def is_nvidia(): Loading @@ -191,7 +192,7 @@ def is_intel(): def soft_empty_cache(): gc.collect() if DEFAULT_DEVICE in ('cuda', 'privateuseone'): if DEFAULT_DEVICE == 'cuda': torch.cuda.empty_cache() torch.cuda.ipc_collect() elif DEFAULT_DEVICE == 'xpu': Loading @@ -204,12 +205,7 @@ def soft_empty_cache(): def DEVICE_SELECTOR(): return deepcopy( { 'type': 'selector', 'options': [ 'cpu', 'cuda', 'mps', 'privateuseone' ], 'options': AVAILABLE_DEVICES, 'value': DEFAULT_DEVICE } ) Loading @@ -218,13 +214,8 @@ def DEVICE_SELECTOR(): return deepcopy( def DEVICE_SELECTOR_NO_DML(): return deepcopy( { 'type': 'selector', 'options': [ 'cpu', 'cuda', 'xpu', 'mps' ], 'value': DEFAULT_DEVICE 'options': [opt for opt in AVAILABLE_DEVICES if not 'privateuseone' in opt], 'value': DEFAULT_DEVICE if DEFAULT_DEVICE != "privateuseone" else "cpu" } ) Loading
modules/ocr/ocr_mit.py +2 −6 Original line number Diff line number Diff line Loading @@ -11,14 +11,10 @@ mit_params = { 'options': [8, 16, 24, 32], 'value': 16 }, 'device': DEVICE_SELECTOR(), 'device': DEVICE_SELECTOR_NO_DML(), 'description': 'OCRMIT32px' } mit_params_no_dml = deepcopy(mit_params) mit_params_no_dml['device'] = DEVICE_SELECTOR_NO_DML() class MITModels(OCRBase): _line_only = True Loading Loading @@ -68,7 +64,7 @@ from .mit48px_ctc import OCR48pxCTC @register_OCR('mit48px_ctc') class OCRMIT48pxCTC(MITModels): params = deepcopy(mit_params_no_dml) params = deepcopy(mit_params) download_file_list = [{ 'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip', 'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'], Loading