Commit cb0cc4ce authored by 98440's avatar 98440
Browse files

add intel XPU support

parent c5f36fcc
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -9,6 +9,17 @@ import subprocess
import importlib.util
import pkg_resources
from platform import platform
from pathlib import Path
import sys
import argparse
import os.path as osp
import os
import importlib
import re
import subprocess
import importlib.util
import pkg_resources
from platform import platform

BRANCH = 'dev'
VERSION = '1.4.0'
+10 −3
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from utils.logger import logger as LOGGER
from utils import shared


GPUINTENSIVE_SET = {'cuda', 'mps'}
GPUINTENSIVE_SET = {'cuda', 'xpu', 'mps'}

def register_hooks(hooks_registered: OrderedDict, callbacks: Union[List, Callable, Dict]):
    if callbacks is None:
@@ -162,10 +162,12 @@ import torch

DEFAULT_DEVICE = 'cpu'
if hasattr(torch, 'cuda') and torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    DEFAULT_DEVICE = 'cuda'
elif hasattr(torch, 'xpu')  and torch.xpu.is_available():
    DEFAULT_DEVICE = 'xpu' if torch.xpu.is_available() else 'cpu'
elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEFAULT_DEVICE = 'mps'
BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported()
BF16_SUPPORTED = DEFAULT_DEVICE == 'cuda' and torch.cuda.is_bf16_supported() or DEFAULT_DEVICE == 'xpu' and torch.xpu.is_bf16_supported()

def is_nvidia():
    if DEFAULT_DEVICE == 'cuda':
@@ -173,11 +175,15 @@ def is_nvidia():
            return True
    return False


def soft_empty_cache():
    gc.collect()
    if DEFAULT_DEVICE == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif DEFAULT_DEVICE == 'xpu':
       torch.xpu.empty_cache()
       # torch.xpu.ipc_collect()
    elif DEFAULT_DEVICE == 'mps':
        torch.mps.empty_cache()

@@ -187,6 +193,7 @@ DEVICE_SELECTOR = lambda : deepcopy(
        'options': [
            'cpu',
            'cuda',
            'xpu',
            'mps'
        ],
        'value': DEFAULT_DEVICE
+1 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ PyQt6>=6.6.1,<6.7.0 ; python_version > "3.8"
PyQt5-Qt5>=5.15.2 ; python_version <= "3.8"
PyQt5>=5.15.10 ; python_version <= "3.8"
numpy<2
urllib3==1.25.11; sys_platform == 'win32' # https://github.com/psf/requests/issues/5740
urllib3; sys_platform == 'win32' # https://github.com/psf/requests/issues/5740
urllib3; sys_platform == 'darwin' # fix urllib3.package.six.move module not found error
jaconv
torch