Commit fc78171c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add docs

parent 15c0b432
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -20,3 +20,9 @@ convert_wd14_emb_to_prediction



denormalize_wd14_emb
----------------------------------------------

.. autofunction:: denormalize_wd14_emb

+41 −0
Original line number Diff line number Diff line
@@ -457,6 +457,16 @@ def _open_denormalize_model(
        model_name: str = _DEFAULT_MODEL_NAME,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
):
    """
    Open a denormalization model for WD14 embeddings.

    :param model_name: Name of the model.
    :type model_name: str
    :param denormalizer_name: Name of the denormalizer.
    :type denormalizer_name: str
    :return: The loaded ONNX model.
    :rtype: ONNXModel
    """
    return open_onnx_model(hf_hub_download(
        repo_id='deepghs/embedding_aligner',
        repo_type='model',
@@ -469,10 +479,41 @@ def denormalize_wd14_emb(
        model_name: str = _DEFAULT_MODEL_NAME,
        denormalizer_name: str = _DEFAULT_DENORMALIZER_NAME,
) -> np.ndarray:
    """
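    Denormalize WD14 embeddings.

    The input is first divided by its L2 norm along the last axis, and the
    result is then passed to the denormalizer model.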

    :param emb: The embedding to denormalize.
    :type emb: numpy.ndarray
    :param model_name: Name of the model.
    :type model_name: str
    :param denormalizer_name: Name of the denormalizer.
    :type denormalizer_name: str
    :return: The denormalized embedding.
    :rtype: numpy.ndarray

    Examples:
        >>> import numpy as np
        >>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction, denormalize_wd14_emb
        ...
        >>> embedding, (r, g, c) = get_wd14_tags(
        ...     'image.png',
        ...     fmt=('embedding', ('rating', 'general', 'character')),
        ... )
        ...
        >>> # normalize the embedding to unit L2 norm
        >>> embedding = embedding / np.linalg.norm(embedding)
        ...
        >>> # denormalize this embedding
        >>> output = denormalize_wd14_emb(embedding)
        ...
        >>> # rating, general, character should be close to r, g, c (error of roughly 1e-3)
        >>> rating, general, character = convert_wd14_emb_to_prediction(output)
    """
    model = _open_denormalize_model(
        model_name=model_name,
        denormalizer_name=denormalizer_name,
    )
    emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)
    if len(emb.shape) == 1:
        output, = model.run(output_names=['embedding'], input_feed={'input': emb[None, ...]})
        return output[0]