Commit e89757b0 authored by yihuishou's avatar yihuishou
Browse files

🍳 Move code to utils/zluda_config.py

parent c40783d1
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -137,6 +137,9 @@ def main():

    prepare_environment()

    from utils.zluda_config import enable_zluda_config
    enable_zluda_config()

    if args.update:
        if getattr(sys, 'frozen', False):
            print('Running as app, skipping update.')
+0 −11
Original line number Diff line number Diff line
@@ -169,17 +169,6 @@ AVAILABLE_DEVICES = ['cpu']
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda'
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
    device_name = torch.cuda.get_device_name(0)
    print('Device name: ', device_name)
    print('Cuda is available: ', torch.cuda.is_available())
    print('Cuda version: ', torch.version.cuda)
    print('ZLUDA is available: ', zluda_available(device_name))
    if zluda_available(device_name):
        torch.backends.cudnn.enabled = False
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_cudnn_sdp(False)
if hasattr(torch, 'xpu')  and torch.xpu.is_available():
    DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu'
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)

utils/zluda_config.py

0 → 100644
+21 −0
Original line number Diff line number Diff line
import torch

# 检测是否包含 ZLUDA 标记
def zluda_available(device_name):
    return "[ZLUDA]" in device_name

# 关闭 ZLUDA Cudnn 支持 防止错误
def enable_zluda_config():
    if hasattr(torch, 'cuda') and torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
        print('Device name: ', device_name)
        print('Cuda is available: ', torch.cuda.is_available())
        print('Cuda version: ', torch.version.cuda)
        print('ZLUDA is available: ', zluda_available(device_name))

        if zluda_available(device_name):
            torch.backends.cudnn.enabled = False
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_cudnn_sdp(False)