Commit c8a4d7fa authored by yihuishou's avatar yihuishou
Browse files

🐛 Fix When using AMD GPU, CUDA 11. x is required

parent d23da43c
Loading
Loading
Loading
Loading
+21 −1
Original line number Diff line number Diff line
@@ -289,6 +289,20 @@ def main():
        ballontrans.resetStyleSheet()
    sys.exit(app.exec())

def is_amd_gpu():
    try:
        if sys.platform == 'win32':
            # Windows: use wmic
            cmd = 'wmic path win32_VideoController get name'
            output = subprocess.check_output(cmd, shell=True, text=True, stderr=subprocess.DEVNULL)
            return any(keyword in output for keyword in ["AMD", "Radeon"])

        else:
            return False

    except Exception:
        return False

def prepare_environment():

    try:
@@ -311,6 +325,12 @@ def prepare_environment():
            if not check_reqs([req]):
                run_pip(f"install {req}", req)
                req_updated = True

    if is_amd_gpu():
        # AMD GPU: Cuda 11.8, Pytorch 2.2.2
        print('AMD GPU: Yes')
        torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 --disable-pip-version-check")
    else:
        torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --disable-pip-version-check")
    if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)