Loading modules/base.py +10 −5 Original line number Diff line number Diff line Loading @@ -158,7 +158,6 @@ class BaseModule: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import torch import torch_directml DEFAULT_DEVICE = 'cpu' AVAILABLE_DEVICES = ['cpu'] Loading @@ -171,10 +170,16 @@ if hasattr(torch, 'xpu') and torch.xpu.is_available(): if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): DEFAULT_DEVICE = 'mps' AVAILABLE_DEVICES.append(DEFAULT_DEVICE) try: import torch_directml if hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: torch.dml = torch_directml DEFAULT_DEVICE = f'privateuseone:{torch.dml.default_device()}' AVAILABLE_DEVICES += [f"privateuseone:{d}" for d in range(torch.dml.device_count())] except: # directml is not supported pass BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported() def is_nvidia(): Loading Loading
modules/base.py +10 −5 Original line number Diff line number Diff line Loading @@ -158,7 +158,6 @@ class BaseModule: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' import torch import torch_directml DEFAULT_DEVICE = 'cpu' AVAILABLE_DEVICES = ['cpu'] Loading @@ -171,10 +170,16 @@ if hasattr(torch, 'xpu') and torch.xpu.is_available(): if hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): DEFAULT_DEVICE = 'mps' AVAILABLE_DEVICES.append(DEFAULT_DEVICE) try: import torch_directml if hasattr(torch, 'privateuseone') and torch_directml.device_count() > 0: torch.dml = torch_directml DEFAULT_DEVICE = f'privateuseone:{torch.dml.default_device()}' AVAILABLE_DEVICES += [f"privateuseone:{d}" for d in range(torch.dml.device_count())] except: # directml is not supported pass BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported() def is_nvidia(): Loading