Commit ad5f3932 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add version check for onnxruntime

parent 2cd4eb82
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -8,8 +8,10 @@ from typing import List, Tuple

import huggingface_hub
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from hbutils.testing.requires.version import VersionInfo

from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
@@ -26,6 +28,8 @@ VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3'
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

_IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17'

MODEL_NAMES = {
    "SwinV2": SWIN_MODEL_REPO,
    "ConvNext": CONV_MODEL_REPO,
@@ -38,6 +42,13 @@ MODEL_NAMES = {
    "ViT_v3": VIT_V3_MODEL_REPO,
}


def _version_support_check(model_name):
    if model_name.endswith('_v3') and not _IS_V3_SUPPORT:
        raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, '
                               f'please upgrade it to 1.17+ version.')  # pragma: no cover


_KAOMOJIS = [
    "0_0",
    "(o)_(o)",
@@ -71,6 +82,7 @@ def _get_wd14_model(model_name):
    :return: The loaded ONNX model.
    :rtype: ONNXModel
    """
    _version_support_check(model_name)
    return open_onnx_model(huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME))