Unverified Commit a05e2252 authored by dmMaze's avatar dmMaze Committed by GitHub
Browse files

Merge pull request #729 from minicom365/feature-directml

Add directml support
parents 6bbeca54 770e5163
Loading
Loading
Loading
Loading
+16 −12
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from utils.logger import logger as LOGGER
from utils import shared


GPUINTENSIVE_SET = {'cuda', 'xpu', 'mps'}
GPUINTENSIVE_SET = {'cuda', 'mps', 'xpu', 'privateuseone'}

def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callable, Dict]):
    if callbacks is None:
@@ -156,17 +156,25 @@ class BaseModule:
    def debug_mode(self):
        return shared.DEBUG


os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import torch_directml

DEFAULT_DEVICE = 'cpu'
AVAILABLE_DEVICES = ['cpu']
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda'
elif hasattr(torch, 'xpu')  and torch.xpu.is_available():
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
if hasattr(torch, 'xpu')  and torch.xpu.is_available():
    DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu'
elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEFAULT_DEVICE = 'mps'
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
if hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0:
    torch.dml = torch_directml
    DEFAULT_DEVICE = f'privateuseone:{torch.dml.default_device()}'
    AVAILABLE_DEVICES += [f"privateuseone:{d}" for d in range(torch.dml.device_count())]
BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported()

def is_nvidia():
@@ -192,16 +200,12 @@ def soft_empty_cache():
    elif DEFAULT_DEVICE == 'mps':
        torch.mps.empty_cache()

DEVICE_SELECTOR = lambda : deepcopy(

def DEVICE_SELECTOR(not_supported:list[str]=[]): return deepcopy(
    {
        'type': 'selector',
        'options': [
            'cpu',
            'cuda',
            'xpu',
            'mps'
        ],
        'value': DEFAULT_DEVICE
        'options': [opt for opt in AVAILABLE_DEVICES if all(device not in opt for device in not_supported)],
        'value': DEFAULT_DEVICE if not any(DEFAULT_DEVICE in device for device in not_supported) else 'cpu'
    }
)

+4 −4
Original line number Diff line number Diff line
@@ -292,7 +292,7 @@ class LamaInpainterMPE(InpainterBase):
            ], 
            'value': 2048
        },
        'device': DEVICE_SELECTOR()
        'device': DEVICE_SELECTOR(not_supported=['privateuseone'])
    }

    download_file_list = [{
@@ -415,7 +415,7 @@ class LamaLarge(LamaInpainterMPE):
            ], 
            'value': 1536,
        },
        'device': DEVICE_SELECTOR(),
        'device': DEVICE_SELECTOR(not_supported=['privateuseone']),
        'precision': {
            'type': 'selector',
            'options': [
+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ mit_params = {
        'options': [8, 16, 24, 32],
        'value': 16
    },
    'device': DEVICE_SELECTOR(),
    'device': DEVICE_SELECTOR(not_supported=['privateuseone']),
    'description': 'OCRMIT32px'
}

+1 −1

File changed.

Contains only whitespace changes.