Commit 0295fd04 authored by yihuishou's avatar yihuishou
Browse files

🚀 Move code to modules/base.py

parent ab5ca73d
Loading
Loading
Loading
Loading
+0 −20
Original line number Diff line number Diff line
@@ -8,7 +8,6 @@ import re
import subprocess
import pkg_resources
from platform import platform
import torch
import logging

BRANCH = 'dev'
@@ -144,11 +143,6 @@ def restart():
        BT.close()
    os.execv(sys.executable, ['python'] + sys.argv)


def zluda_available(device_name):
    return "[ZLUDA]" in device_name


def main():

    if args.debug:
@@ -164,20 +158,6 @@ def main():
    print(f'Branch: {BRANCH}')
    print(f"Commit hash: {commit}")

    if hasattr(torch, 'cuda'):
        device_name = torch.cuda.get_device_name(0)
        print('Device name: ', device_name)
        print('Cuda is available: ',torch.cuda.is_available())
        print('Cuda version: ', torch.version.cuda)
        print('ZLUDA is available: ', zluda_available(device_name))

        if zluda_available(device_name):
            torch.backends.cudnn.enabled = False
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_cudnn_sdp(False)

    APP_DIR = os.path.dirname(os.path.abspath(__file__))
    os.chdir(APP_DIR)

+14 −0
Original line number Diff line number Diff line
@@ -159,11 +159,25 @@ class BaseModule:
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch

def zluda_available(device_name):
    return "[ZLUDA]" in device_name

DEFAULT_DEVICE = 'cpu'
AVAILABLE_DEVICES = ['cpu']
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda'
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)
    device_name = torch.cuda.get_device_name(0)
    print('Device name: ', device_name)
    print('Cuda is available: ', torch.cuda.is_available())
    print('Cuda version: ', torch.version.cuda)
    print('ZLUDA is available: ', zluda_available(device_name))
    if zluda_available(device_name):
        torch.backends.cudnn.enabled = False
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_cudnn_sdp(False)
if hasattr(torch, 'xpu')  and torch.xpu.is_available():
    DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu'
    AVAILABLE_DEVICES.append(DEFAULT_DEVICE)