Unverified Commit c7a5c158 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #42 from deepghs/dev/onnx

dev(narugo): use both cpu and cuda provider when creating onnx session
parents da7d5910 3e06de44
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -63,14 +63,18 @@ def get_onnx_provider(provider: Optional[str] = None):
                         f'but unsupported provider {provider!r} found.')


def _open_onnx_model(ckpt: str, provider: str) -> InferenceSession:
def _open_onnx_model(ckpt: str, provider: str, use_cpu: bool = True) -> InferenceSession:
    options = SessionOptions()
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    if provider == "CPUExecutionProvider":
        options.intra_op_num_threads = os.cpu_count()

    providers = [provider]
    if use_cpu and "CPUExecutionProvider" not in providers:
        providers.append("CPUExecutionProvider")

    logging.info(f'Model {ckpt!r} loaded with provider {provider!r}')
    return InferenceSession(ckpt, options, [provider])
    return InferenceSession(ckpt, options, providers=providers)


def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession: