Commit 2cd4eb82 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): support wd14 v3 taggers

parent e1cd4c8c
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -59,6 +59,9 @@ if __name__ == '__main__':
            ('wd14-convnextv2', Wd14Benchmark("ConvNextV2")),
            ('wd14-vit', Wd14Benchmark("ViT")),
            ('wd14-moat', Wd14Benchmark("MOAT")),
            ('wd14-swinv2-v3', Wd14Benchmark("SwinV2_v3")),
            ('wd14-vit-v3', Wd14Benchmark("ViT_v3")),
            ('wd14-convnext-v3', Wd14Benchmark("ConvNext_v3")),
            ('mldanbooru', MLDanbooruBenchmark()),
        ],
        title='Benchmark for Tagging Models',
+0 −2678

File deleted.

Preview size limit exceeded, changes collapsed.

+40 −15
Original line number Diff line number Diff line
@@ -61,17 +61,31 @@ _KAOMOJIS = [
]


def _load_wd14_model(model_repo: str, model_filename: str):
    return open_onnx_model(huggingface_hub.hf_hub_download(model_repo, model_filename))


@lru_cache()
def _get_wd14_model(model_name):
    return _load_wd14_model(MODEL_NAMES[model_name], MODEL_FILENAME)
    """
    Load an ONNX model from the Hugging Face Hub.

    :param model_name: The name of the model.
    :type model_name: str
    :return: The loaded ONNX model.
    :rtype: ONNXModel
    """
    return open_onnx_model(huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], MODEL_FILENAME))


@lru_cache()
def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]:
    """
    Get labels for the WD14 model.

    :param model_name: The name of the model.
    :type model_name: str
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :type no_underline: bool
    :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories.
    :rtype: Tuple[List[str], List[int], List[int], List[int]]
    """
    path = huggingface_hub.hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME)
    df = pd.read_csv(path)
    name_series = df["name"]
@@ -131,16 +145,27 @@ def get_wd14_tags(
):
    """
    Overview:
        Tagging image by wd14 v2 model. Similar to
        `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .

    :param image: Image to tagging.
    :param model_name: Name of the mode, should be one of the \
        ``SwinV2``, ``ConvNext``, ``ConvNextV2``, ``ViT`` or ``MOAT``, default is ``ConvNextV2``.
    :param general_threshold: Threshold for default tags, default is ``0.35``.
    :param character_threshold: Threshold for character tags, default is ``0.85``.
    :param drop_overlap: Drop overlap tags or not, default is ``False``.
    :return: Tagging results for levels, features and characters.
        Get tags for an image with wd14 taggers.
        Similar to `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The name of the model to use.
    :type model_name: str
    :param general_threshold: The threshold for general tags.
    :type general_threshold: float
    :param general_mcut_enabled: If True, applies MCut thresholding to general tags.
    :type general_mcut_enabled: bool
    :param character_threshold: The threshold for character tags.
    :type character_threshold: float
    :param character_mcut_enabled: If True, applies MCut thresholding to character tags.
    :type character_mcut_enabled: bool
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :type no_underline: bool
    :param drop_overlap: If True, drops overlapping tags.
    :type drop_overlap: bool
    :return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
    :rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]

    Example:
        Here are some images for example