Commit 68c16fd9 authored by minicom's avatar minicom
Browse files

Some fixes and changed the device list to only show available devices.

parent 5ea39d4c
Loading
Loading
Loading
Loading
+15 −24
Original line number Diff line number Diff line
@@ -156,25 +156,26 @@ class BaseModule:
    def debug_mode(self):
        return shared.DEBUG


os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import torch_directml

DEFAULT_DEVICE = 'cpu'
AVAILABLE_DEVICES = ['cpu']
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda'
elif hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0:
    from modules.dml import directml_init, directml_do_hijack
    directml_init()
    directml_do_hijack()
    for d in range(torch.cuda.device_count()):
        print(f"device {d}: {torch.cuda.get_device_name(d)}")
    DEFAULT_DEVICE = 'cpu'
elif hasattr(torch, 'xpu')  and torch.xpu.is_available():
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
if hasattr(torch, 'xpu')  and torch.xpu.is_available():
    DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu'
elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEFAULT_DEVICE = 'mps'
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
if hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0:
    torch.dml = torch_directml
    [f"privateuseone:{d}" for d in range(torch.dml.device_count())]
    DEFAULT_DEVICE = f'privateuseone:{torch.dml.default_device()}'
    AVAILABLE_DEVICES += [f"privateuseone:{d}" for d in range(torch.dml.device_count())]
BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported()

def is_nvidia():
@@ -191,7 +192,7 @@ def is_intel():

def soft_empty_cache():
    gc.collect()
    if DEFAULT_DEVICE in ('cuda', 'privateuseone'):
    if DEFAULT_DEVICE == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif DEFAULT_DEVICE == 'xpu':
@@ -204,12 +205,7 @@ def soft_empty_cache():
def DEVICE_SELECTOR(): return deepcopy(
    {
        'type': 'selector',
        'options': [
            'cpu',
            'cuda',
            'mps',
            'privateuseone'
        ],
        'options': AVAILABLE_DEVICES,
        'value': DEFAULT_DEVICE
    }
)
@@ -218,13 +214,8 @@ def DEVICE_SELECTOR(): return deepcopy(
def DEVICE_SELECTOR_NO_DML(): return deepcopy(
    {
        'type': 'selector',
        'options': [
            'cpu',
            'cuda',
            'xpu',
            'mps'
        ],
        'value': DEFAULT_DEVICE
        'options': [opt for opt in AVAILABLE_DEVICES if not 'privateuseone' in opt],
        'value': DEFAULT_DEVICE if DEFAULT_DEVICE != "privateuseone" else "cpu"
    }
)

+2 −6
Original line number Diff line number Diff line
@@ -11,14 +11,10 @@ mit_params = {
        'options': [8, 16, 24, 32],
        'value': 16
    },
    'device': DEVICE_SELECTOR(),
    'device': DEVICE_SELECTOR_NO_DML(),
    'description': 'OCRMIT32px'
}

mit_params_no_dml = deepcopy(mit_params)
mit_params_no_dml['device'] = DEVICE_SELECTOR_NO_DML()


class MITModels(OCRBase):

    _line_only = True
@@ -68,7 +64,7 @@ from .mit48px_ctc import OCR48pxCTC
@register_OCR('mit48px_ctc')
class OCRMIT48pxCTC(MITModels):

    params = deepcopy(mit_params_no_dml)
    params = deepcopy(mit_params)
    download_file_list = [{
        'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip',
        'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'],