Commit a2a984ec authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add docs for this submodule

parent 45ddfcbc
Loading
Loading
Loading
Loading
+87 −11
Original line number Diff line number Diff line
"""
Overview:
    Management of onnx models.
    Management of ONNX models with automatic runtime detection and provider selection.

    This module provides utilities for loading and managing ONNX models with support for
    different execution providers (CPU, CUDA, TensorRT). It automatically handles the
    installation of onnxruntime based on the system configuration and provides a
    convenient interface for model inference.
"""
import logging
import os
import shutil
import warnings
from typing import Optional

from hbutils.system import pip_install
@@ -15,6 +21,14 @@ __all__ = [


def _ensure_onnxruntime():
    """
    Ensure that onnxruntime is installed on the system.

    This function automatically detects whether an NVIDIA GPU is available and installs
    the appropriate build of onnxruntime (GPU or CPU version).

    :raises ImportError: If installation fails
    """
    try:
        import onnxruntime
    except (ImportError, ModuleNotFoundError):
@@ -39,13 +53,35 @@ alias = {

def get_onnx_provider(provider: Optional[str] = None):
    """
    Overview:
        Get onnx provider.
    Get the appropriate ONNX execution provider based on system capabilities and user preference.

    This function automatically detects available execution providers and returns the most
    suitable one. It supports aliases for common providers and falls back to CPU execution
    if GPU providers are not available.

    :param provider: The provider for ONNX runtime. ``None`` by default and will automatically detect
        if the ``CUDAExecutionProvider`` is available. If it is available, it will be used,
        otherwise the default ``CPUExecutionProvider`` will be used.
    :return: String of the provider.
        otherwise the default ``CPUExecutionProvider`` will be used. Supported aliases include
        'gpu' for CUDAExecutionProvider and 'trt' for TensorrtExecutionProvider.
    :type provider: Optional[str]

    :return: String name of the selected execution provider.
    :rtype: str

    :raises ValueError: If the specified provider is not supported or available.

    Example::
        >>> # Auto-detect provider
        >>> provider = get_onnx_provider()
        >>> print(provider)  # 'CUDAExecutionProvider' or 'CPUExecutionProvider'

        >>> # Explicitly request GPU provider
        >>> provider = get_onnx_provider('gpu')
        >>> print(provider)  # 'CUDAExecutionProvider'

        >>> # Request CPU provider
        >>> provider = get_onnx_provider('cpu')
        >>> print(provider)  # 'CPUExecutionProvider'
    """
    if not provider:
        if "CUDAExecutionProvider" in get_available_providers():
@@ -65,6 +101,24 @@ def get_onnx_provider(provider: Optional[str] = None):

def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True,
                     cuda_device_id: Optional[int] = None) -> InferenceSession:
    """
    Internal function to create and configure an ONNX inference session.

    This function handles the low-level configuration of the ONNX runtime session,
    including optimization settings and provider-specific configurations.

    :param ckpt: Path to the ONNX model file.
    :type ckpt: str
    :param provider: Name of the execution provider to use.
    :type provider: str
    :param use_cpu: Whether to include CPU provider as fallback. Defaults to True.
    :type use_cpu: bool
    :param cuda_device_id: Specific CUDA device ID to use for GPU inference.
    :type cuda_device_id: Optional[int]

    :return: Configured ONNX inference session.
    :rtype: InferenceSession
    """
    options = SessionOptions()
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    if provider == "CPUExecutionProvider":
@@ -75,6 +129,9 @@ def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True,
            ('CUDAExecutionProvider', {'device_id': cuda_device_id}),
        ]
    else:
        if provider != 'CUDAExecutionProvider' and cuda_device_id is not None:
            warnings.warn(UserWarning(
                'CUDA device ID specified but provider is not CUDAExecutionProvider. The device ID will be ignored.'))
        providers = [provider]
    if use_cpu and "CPUExecutionProvider" not in providers:
        providers.append("CPUExecutionProvider")
@@ -85,19 +142,38 @@ def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True,

def open_onnx_model(ckpt: str, mode: str = None, cuda_device_id: Optional[int] = None) -> InferenceSession:
    """
    Overview:
        Open an ONNX model and load its ONNX runtime.
    Open an ONNX model and create a configured inference session.

    :param ckpt: ONNX model file.
    :param mode: Provider of the ONNX. Default is ``None`` which means the provider will be auto-detected,
        see :func:`get_onnx_provider` for more details.
    :return: A loaded ONNX runtime object.
    This function provides a high-level interface for loading ONNX models with
    automatic provider selection and optimization. It supports environment variable
    configuration for runtime provider selection.

    :param ckpt: Path to the ONNX model file to load.
    :type ckpt: str
    :param mode: Provider of the ONNX runtime. Default is ``None`` which means the provider will be auto-detected,
        see :func:`get_onnx_provider` for more details. Can also be controlled via the ``ONNX_MODE`` environment variable.
    :type mode: Optional[str]
    :param cuda_device_id: Specific CUDA device ID to use for GPU inference. Only effective when using the CUDA provider.
    :type cuda_device_id: Optional[int]

    :return: A loaded and configured ONNX inference session ready for prediction.
    :rtype: InferenceSession

    .. note::
        When ``mode`` is set to ``None``, it will attempt to detect the environment variable ``ONNX_MODE``.
        This means you can decide which ONNX runtime to use by setting the environment variable. For example,
        on Linux, executing ``export ONNX_MODE=cpu`` will ignore any available CUDA device and force model inference
        to run on the CPU.

    Example::
        >>> # Load model with auto-detected provider
        >>> session = open_onnx_model('model.onnx')

        >>> # Force CPU execution
        >>> session = open_onnx_model('model.onnx', mode='cpu')

        >>> # Use specific CUDA device
        >>> session = open_onnx_model('model.onnx', mode='gpu', cuda_device_id=1)
    """
    return _open_onnx_model(
        ckpt=ckpt,