Commit 85869739 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add camie pydocs

parent 19255f11
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -13,3 +13,9 @@ get_camie_tags



convert_camie_emb_to_prediction
----------------------------------------------------

.. autofunction:: convert_camie_emb_to_prediction

+54 −1
Original line number Diff line number Diff line
@@ -150,7 +150,7 @@ def _postprocess_embedding_values(
    :type embedding: numpy.ndarray
    :param model_name: Name of the model used
    :type model_name: str
    :param mode: Prediction mode for threshold selection
    :param mode: Prediction mode affecting threshold values
    :type mode: CamieModeTyping
    :param thresholds: Custom thresholds for tag selection
    :type thresholds: Optional[Union[float, Dict[str, float]]]
@@ -232,6 +232,15 @@ def get_camie_tags(
    :return: Extracted tags and embeddings, follow the format from ``fmt``.
    :rtype: Any

    .. note::
        Modes for selection:

        - ``balanced``: Balanced precision/recall
        - ``high_precision``: Higher precision thresholds
        - ``high_recall``: Higher recall thresholds
        - ``micro_opt``: Micro-optimized thresholds
        - ``macro_opt``: Macro-optimized thresholds

    .. note::
        The fmt argument can include the following keys:

@@ -349,6 +358,16 @@ def get_camie_tags(

@ts_lru_cache()
def _get_camie_emb_to_pred_model(model_name: str, is_refined: bool = False):
    """
    Load embedding-to-prediction conversion model.

    :param model_name: Model variant name
    :type model_name: str
    :param is_refined: Use refined embeddings (True) or initial embeddings (False)
    :type is_refined: bool
    :return: ONNX model session for embedding conversion
    :rtype: onnxruntime.InferenceSession
    """
    return open_onnx_model(hf_hub_download(
        repo_id=_REPO_ID,
        repo_type='model',
@@ -366,6 +385,40 @@ def convert_camie_emb_to_prediction(
        drop_overlap: bool = False,
        fmt: Any = ('rating', 'general', 'character'),
):
    """
    Convert stored embeddings back to tag predictions.

    Useful for reprocessing existing embeddings with new thresholds or formats.

    :param emb: Embedding vector(s) from previous inference
    :type emb: np.ndarray
    :param model_name: Original model variant name
    :type model_name: str
    :param is_refined: Whether embeddings come from refined stage, otherwise from initial stage
    :type is_refined: bool
    :param mode: Threshold selection strategy
    :type mode: CamieModeTyping
    :param thresholds: Custom threshold values
    :type thresholds: Optional[Union[float, Dict[str, float]]]
    :param no_underline: Remove underscores from tag names
    :type no_underline: bool
    :param drop_overlap: Remove overlapping tags in general category
    :type drop_overlap: bool
    :param fmt: Output format specification
    :type fmt: Any
    :return: Formatted results matching original prediction format
    :rtype: Any

    .. note::
        Modes for selection:

        - ``balanced``: Balanced precision/recall
        - ``high_precision``: Higher precision thresholds
        - ``high_recall``: Higher recall thresholds
        - ``micro_opt``: Micro-optimized thresholds
        - ``macro_opt``: Macro-optimized thresholds

    """
    model = _get_camie_emb_to_pred_model(model_name=model_name, is_refined=is_refined)
    if len(emb.shape) == 1:
        logits, pred = model.run(["logits", "output"], {'embedding': emb[np.newaxis]})