Commit 45ddfcbc authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add device id

parent 4d617dcf
Loading
Loading
Loading
Loading
+15 −4
Original line number Diff line number Diff line
@@ -63,12 +63,18 @@ def get_onnx_provider(provider: Optional[str] = None):
                         f'but unsupported provider {provider!r} found.')


def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> InferenceSession:
def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True,
                     cuda_device_id: Optional[int] = None) -> InferenceSession:
    options = SessionOptions()
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    if provider == "CPUExecutionProvider":
        options.intra_op_num_threads = os.cpu_count()

    if provider == 'CUDAExecutionProvider' and cuda_device_id is not None:
        providers = [
            ('CUDAExecutionProvider', {'device_id': cuda_device_id}),
        ]
    else:
        providers = [provider]
    if use_cpu and "CPUExecutionProvider" not in providers:
        providers.append("CPUExecutionProvider")
@@ -77,7 +83,7 @@ def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> Inferenc
    return InferenceSession(ckpt, options, providers=providers)


def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession:
def open_onnx_model(ckpt: str, mode: str = None, cuda_device_id: Optional[int] = None) -> InferenceSession:
    """
    Overview:
        Open an ONNX model and load its ONNX runtime.
@@ -93,4 +99,9 @@ def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession:
        on Linux, executing ``export ONNX_MODE=cpu`` will ignore any existing CUDA and force the model inference
        to run on CPU.
    """
    return _open_onnx_model(ckpt, get_onnx_provider(mode or os.environ.get('ONNX_MODE', None)))
    return _open_onnx_model(
        ckpt=ckpt,
        provider=get_onnx_provider(mode or os.environ.get('ONNX_MODE', None)),
        use_cpu=True,
        cuda_device_id=cuda_device_id,
    )