Loading imgutils/tagging/wd14.py +12 −0 Original line number Diff line number Diff line Loading @@ -8,8 +8,10 @@ from typing import List, Tuple import huggingface_hub import numpy as np import onnxruntime import pandas as pd from PIL import Image from hbutils.testing.requires.version import VersionInfo from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping Loading @@ -26,6 +28,8 @@ VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3' MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" _IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17' MODEL_NAMES = { "SwinV2": SWIN_MODEL_REPO, "ConvNext": CONV_MODEL_REPO, Loading @@ -38,6 +42,13 @@ MODEL_NAMES = { "ViT_v3": VIT_V3_MODEL_REPO, } def _version_support_check(model_name): if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.') # pragma: no cover _KAOMOJIS = [ "0_0", "(o)_(o)", Loading Loading @@ -71,6 +82,7 @@ def _get_wd14_model(model_name): :return: The loaded ONNX model. :rtype: ONNXModel """ _version_support_check(model_name) return open_onnx_model(huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)) Loading Loading
imgutils/tagging/wd14.py +12 −0 Original line number Diff line number Diff line Loading @@ -8,8 +8,10 @@ from typing import List, Tuple import huggingface_hub import numpy as np import onnxruntime import pandas as pd from PIL import Image from hbutils.testing.requires.version import VersionInfo from .overlap import drop_overlap_tags from ..data import load_image, ImageTyping Loading @@ -26,6 +28,8 @@ VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3' MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" _IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17' MODEL_NAMES = { "SwinV2": SWIN_MODEL_REPO, "ConvNext": CONV_MODEL_REPO, Loading @@ -38,6 +42,13 @@ MODEL_NAMES = { "ViT_v3": VIT_V3_MODEL_REPO, } def _version_support_check(model_name): if model_name.endswith('_v3') and not _IS_V3_SUPPORT: raise EnvironmentError(f'V3 taggers not supported on onnxruntime {onnxruntime.__version__}, ' f'please upgrade it to 1.17+ version.') # pragma: no cover _KAOMOJIS = [ "0_0", "(o)_(o)", Loading Loading @@ -71,6 +82,7 @@ def _get_wd14_model(model_name): :return: The loaded ONNX model. :rtype: ONNXModel """ _version_support_check(model_name) return open_onnx_model(huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME)) Loading