Commit 24d04e52 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add docs for predict_fmt

parent 3e2f04a2
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ ClassifyModel
-----------------------------------------

.. autoclass:: ClassifyModel
    :members: __init__, predict_score, predict, clear, make_ui, launch_demo
    :members: __init__, predict_score, predict, predict_fmt, clear, make_ui, launch_demo



@@ -29,3 +29,10 @@ classify_predict



classify_predict_fmt
-----------------------------------------

.. autofunction:: classify_predict_fmt


+61 −5
Original line number Diff line number Diff line
@@ -383,6 +383,31 @@ class ClassifyModel:
        return self._open_label(model_name)[label_group][max_id], output[max_id].item()

    def predict_fmt(self, image: ImageTyping, model_name: str, fmt='scores-top5'):
        """
        Predict the scores for each class with the given format specification, using the specified model.

        :param image: The input image to classify.
        :type image: ImageTyping
        :param model_name: The name of the model to use for prediction.
        :type model_name: str
        :param fmt: Format specification. Default is ``scores-top5``.

        :return: Prediction result formatted with parameter ``fmt``.

        :raises ValueError: If the model name is invalid.
        :raises RuntimeError: If there's an error during prediction.

        .. note::
            The following specifications are supported in parameter ``fmt``:

            - ``output``, raw prediction result, in np.ndarray format.
            - ``logits``, (not available in some models) logits result, in np.ndarray format.
            - ``embedding``, (not available in some models) embeddings result, in np.ndarray format.
            - ``scores``, prediction scores of all classes in dict format.
            - ``scores-topK``, prediction scores of top-K classes in dict format, e.g. ``scores-top10`` means top 10 scores.
            - ``scores-<label_group>``, prediction scores of all classes with label group ``<label_group>``, e.g. ``scores-descriptions`` means all scores with ``descriptions`` label group.
            - ``scores-topK-<label_group>``, prediction scores of top-K classes with label group ``<label_group>``.
        """
        d_data = {name: value[0] for name, value in self._raw_predict(image, model_name).items()}
        scores = d_data['output']
        d_labels = self._open_label(model_name)
@@ -392,11 +417,11 @@ class ClassifyModel:
            matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname)
            if matching:
                topk = int(matching.group('topk')) if matching.group('topk') else None
                group_label = matching.group('label_group') if matching.group('label_group') else 'default'
                vname_to_spair[vname] = (topk, group_label)
                if (topk, group_label) not in d_scores:
                    d_scores[(topk, group_label)] = _labels_scores_to_topk(
                        labels=d_labels[group_label],
                label_group = matching.group('label_group') if matching.group('label_group') else 'default'
                vname_to_spair[vname] = (topk, label_group)
                if (topk, label_group) not in d_scores:
                    d_scores[(topk, label_group)] = _labels_scores_to_topk(
                        labels=d_labels[label_group],
                        scores=scores,
                        topk=topk,
                    )
@@ -620,6 +645,37 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str, label_gr

def classify_predict_fmt(image: ImageTyping, repo_id: str, model_name: str, fmt='scores-top5',
                         hf_token: Optional[str] = None):
    """
    Predict the scores for each class with the given format specification, using the specified model.

    This function is a convenience wrapper around ClassifyModel's ``predict_fmt`` method.

    :param image: The input image to classify.
    :type image: ImageTyping
    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param fmt: Format specification. Default is ``scores-top5``.
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :return: Prediction result formatted with parameter ``fmt``.

    :raises ValueError: If the model name is invalid.
    :raises RuntimeError: If there's an error during prediction.

    .. note::
        The following specifications are supported in parameter ``fmt``:

        - ``output``, raw prediction result, in np.ndarray format.
        - ``logits``, (not available in some models) logits result, in np.ndarray format.
        - ``embedding``, (not available in some models) embeddings result, in np.ndarray format.
        - ``scores``, prediction scores of all classes in dict format.
        - ``scores-topK``, prediction scores of top-K classes in dict format, e.g. ``scores-top10`` means top 10 scores.
        - ``scores-<label_group>``, prediction scores of all classes with label group ``<label_group>``, e.g. ``scores-descriptions`` means all scores with ``descriptions`` label group.
        - ``scores-topK-<label_group>``, prediction scores of top-K classes with label group ``<label_group>``.
    """
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_fmt(
        image=image,
        model_name=model_name,