Commit 256758af authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add all docs

parent fc78171c
Loading
Loading
Loading
Loading
+32 −12
Original line number Diff line number Diff line
@@ -376,40 +376,60 @@ def convert_wd14_emb_to_prediction(
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
    """
    Convert WD14 embedding to understandable prediction result.
    Convert WD14 embedding to understandable prediction result. This function can process both
    single embeddings (1-dimensional array) and batches of embeddings (2-dimensional array).

    :param emb: The 1-dim extracted embedding.
    :param emb: The extracted embedding(s). Can be either a 1-dim array for a single image or
                a 2-dim array for batch processing
    :type emb: numpy.ndarray
    :param model_name: The name of the model to use.
    :param model_name: Name of the WD14 model to use for prediction
    :type model_name: str
    :param general_threshold: The threshold for general tags.
    :param general_threshold: Confidence threshold for general tags (0.0 to 1.0)
    :type general_threshold: float
    :param general_mcut_enabled: If True, applies MCut thresholding to general tags.
    :param general_mcut_enabled: Enable MCut thresholding for general tags to improve prediction quality
    :type general_mcut_enabled: bool
    :param character_threshold: The threshold for character tags.
    :param character_threshold: Confidence threshold for character tags (0.0 to 1.0)
    :type character_threshold: float
    :param character_mcut_enabled: If True, applies MCut thresholding to character tags.
    :param character_mcut_enabled: Enable MCut thresholding for character tags to improve prediction quality
    :type character_mcut_enabled: bool
    :param no_underline: If True, replaces underscores in tag names with spaces.
    :param no_underline: Replace underscores with spaces in tag names for better readability
    :type no_underline: bool
    :param drop_overlap: If True, drops overlapping tags.
    :param drop_overlap: Remove overlapping tags to reduce redundancy
    :type drop_overlap: bool
    :param fmt: Return format, default is ``('rating', 'general', 'character')``.
    :return: Prediction result based on the provided fmt.
    :param fmt: Specify return format structure for predictions, default is ``('rating', 'general', 'character')``.
    :type fmt: tuple
    :param denormalize: Whether to denormalize the embedding before prediction
    :type denormalize: bool
    :param denormalizer_name: Name of the denormalizer to use if denormalization is enabled
    :type denormalizer_name: str
    :return: For single embeddings: prediction result based on fmt. For batches: list of prediction results.

    .. note::
        Only embeddings that have not been normalized can be converted to an understandable prediction result.
        If normalized embeddings are provided, set ``denormalize=True`` to convert them back.

    For batch processing (2-dim input), returns a list where each element corresponds
    to one embedding's predictions in the same format as single embedding output.

    Example:
        >>> import os
        >>> import numpy as np
        >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
        >>>
        >>> # extract the feature embedding
        >>> # extract the feature embedding, shape: (W, )
        >>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding')
        >>>
        >>> # convert to understandable result
        >>> rating, general, character = convert_wd14_emb_to_prediction(embedding)
        >>> # these 3 dicts will be the same as those returned by `get_wd14_tags('skadi.jpg')`
        >>>
        >>> # Batch processing, shape: (B, W)
        >>> embeddings = np.stack([
        ...     get_wd14_tags('img1.jpg', fmt='embedding'),
        ...     get_wd14_tags('img2.jpg', fmt='embedding'),
        ... ])
        >>> # results will be a list of (rating, general, character) tuples
        >>> results = convert_wd14_emb_to_prediction(embeddings)
    """
    if denormalize:
        emb = denormalize_wd14_emb(