Commit 2c87f24c authored by dmMaze's avatar dmMaze
Browse files

fix importing torch_directml

parent a05e2252
Loading
Loading
Loading
Loading
+10 −5
Original line number Diff line number Diff line
@@ -158,7 +158,6 @@ class BaseModule:

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
import torch_directml

DEFAULT_DEVICE = 'cpu'
AVAILABLE_DEVICES = ['cpu']
@@ -171,10 +170,16 @@ if hasattr(torch, 'xpu') and torch.xpu.is_available():
if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEFAULT_DEVICE = 'mps'
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)

try: 
    import torch_directml
    if hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0:
        torch.dml = torch_directml
        DEFAULT_DEVICE = f'privateuseone:{torch.dml.default_device()}'
        AVAILABLE_DEVICES += [f"privateuseone:{d}" for d in range(torch.dml.device_count())]
except:
    # directml is not supported
    pass
BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported()

def is_nvidia():