Commit 4145012c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add pydocs for multilabel

parent 5f79bf0c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -12,5 +12,6 @@ imgutils.generic
    classify
    enhance
    clip
    multilabel_timm
    siglip
    yolo
+23 −0
Original line number Diff line number Diff line
imgutils.generic.multilabel_timm
=======================================

.. currentmodule:: imgutils.generic.multilabel_timm

.. automodule:: imgutils.generic.multilabel_timm



MultiLabelTIMMModel
----------------------------------------------

.. autoclass:: MultiLabelTIMMModel
    :members: __init__,predict,make_ui,launch_demo



multilabel_timm_predict
--------------------------------------------------

.. autofunction:: multilabel_timm_predict

+208 −0
Original line number Diff line number Diff line
"""
Multi-Label TIMM Model Module

This module provides functionality for working with multi-label image classification models
trained with TIMM (PyTorch Image Models) and exported to ONNX format. It includes:

1. The MultiLabelTIMMModel class for loading and making predictions with models hosted on Hugging Face Hub
2. Functions for batch prediction and demo interface creation
3. Support for custom thresholds at both category and tag levels
4. Flexible output formatting options for different use cases

The models are expected to be stored on Hugging Face Hub with specific files:

- model.onnx: The ONNX model file
- selected_tags.csv: CSV file containing tag information and categories
- preprocess.json: JSON configuration for image preprocessing
- thresholds.csv: Optional CSV file with recommended thresholds

This module is designed to work with multi-label classification tasks where images can
belong to multiple categories and have multiple tags within each category.
"""

import io
import json
import os
@@ -45,7 +67,31 @@ FMT_UNSET = object()


class MultiLabelTIMMModel:
    """
    A class for working with multi-label image classification models trained with TIMM.

    This class handles loading models from Hugging Face Hub, preprocessing images,
    and making predictions with customizable thresholds.

    :param repo_id: The Hugging Face Hub repository ID containing the model
    :type repo_id: str
    :param hf_token: Optional Hugging Face authentication token for private repositories
    :type hf_token: Optional[str]
    :param category_names: Optional mapping of category IDs to display names
    :type category_names: Dict[Any, str]
    """

    def __init__(self, repo_id: str, hf_token: Optional[str] = None, category_names: Dict[Any, str] = None):
        """
        Initialize a MultiLabelTIMMModel.

        :param repo_id: The Hugging Face Hub repository ID containing the model
        :type repo_id: str
        :param hf_token: Optional Hugging Face authentication token for private repositories
        :type hf_token: Optional[str]
        :param category_names: Optional mapping of category IDs to display names
        :type category_names: Dict[Any, str]
        """
        self.repo_id = repo_id
        self._model = None
        self._df_tags = None
@@ -68,6 +114,12 @@ class MultiLabelTIMMModel:
        return self._hf_token or os.environ.get('HF_TOKEN')

    def _open_model(self):
        """
        Load the ONNX model from Hugging Face Hub.

        :return: The loaded ONNX model
        :rtype: object
        """
        with self._lock:
            if self._model is None:
                self._model = open_onnx_model(hf_hub_download(
@@ -80,6 +132,12 @@ class MultiLabelTIMMModel:
        return self._model

    def _open_tags(self):
        """
        Load tag information from the Hugging Face Hub.

        :return: DataFrame containing tag information
        :rtype: pandas.DataFrame
        """
        with self._lock:
            if self._df_tags is None:
                self._df_tags = pd.read_csv(hf_hub_download(
@@ -97,6 +155,12 @@ class MultiLabelTIMMModel:
        return self._df_tags

    def _open_preprocess(self):
        """
        Load preprocessing configuration from the Hugging Face Hub.

        :return: A tuple of validation and test preprocessing transforms
        :rtype: tuple
        """
        with self._lock:
            if self._preprocess is None:
                with open(hf_hub_download(
@@ -112,6 +176,12 @@ class MultiLabelTIMMModel:
        return self._preprocess

    def _open_default_category_thresholds(self):
        """
        Load default category thresholds from the Hugging Face Hub.

        :return: Dictionary mapping category IDs to threshold values
        :rtype: dict
        """
        with self._lock:
            if self._default_category_thresholds is None:
                try:
@@ -131,6 +201,18 @@ class MultiLabelTIMMModel:
        return self._default_category_thresholds

    def _raw_predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test'):
        """
        Make a raw prediction with the model.

        :param image: The input image
        :type image: ImageTyping
        :param preprocessor: Which preprocessor to use ('test' or 'val')
        :type preprocessor: Literal['test', 'val']

        :return: Dictionary of model outputs
        :rtype: dict
        :raises ValueError: If an unknown preprocessor is specified
        """
        image = load_image(image, force_background='white', mode='RGB')
        model = self._open_model()

@@ -150,6 +232,39 @@ class MultiLabelTIMMModel:
    def predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test',
                thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True,
                fmt=FMT_UNSET):
        """
        Make a prediction and format the results.

        This method processes an image through the model and applies thresholds to determine
        which tags to include in the results. The output format can be customized using the fmt parameter.

        :param image: The input image
        :type image: ImageTyping
        :param preprocessor: Which preprocessor to use ('test' or 'val')
        :type preprocessor: Literal['test', 'val']
        :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories
                          or a dictionary mapping category IDs or names to threshold values
        :type thresholds: Union[float, Dict[Any, float]]
        :param use_tag_thresholds: Whether to use tag-level thresholds if available
        :type use_tag_thresholds: bool
        :param fmt: Output format specification. Can be a tuple of category names to include,
                   or FMT_UNSET to use all categories
        :type fmt: Any

        :return: Formatted prediction results according to the fmt parameter
        :rtype: Any

        .. note::
            The fmt argument can include the following keys:

            - Category names: dicts containing category-specific tags and their confidences
            - ``tag``: a dict containing all tags across categories and their confidences
            - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization
            - ``logits``: a 1-dim logits result of image.
            - ``prediction``: a 1-dim prediction result of image

            You can extract specific category predictions or all tags based on your needs.
        """
        df_tags = self._open_tags()
        values = self._raw_predict(image, preprocessor=preprocessor)
        prediction = values['prediction']
@@ -200,6 +315,17 @@ class MultiLabelTIMMModel:

    def make_ui(self, default_thresholds: Union[float, Dict[Any, float]] = None,
                default_use_tag_thresholds: bool = True):
        """
        Create a Gradio UI for the model.

        :param default_thresholds: Default threshold values to use in the UI
        :type default_thresholds: Union[float, Dict[Any, float]]
        :param default_use_tag_thresholds: Whether to use tag-level thresholds by default
        :type default_use_tag_thresholds: bool

        :return: None
        :raises EnvironmentError: If Gradio is not installed
        """
        _check_gradio_env()
        df_tags = self._open_tags()
        default_category_thresholds = self._open_default_category_thresholds()
@@ -296,6 +422,23 @@ class MultiLabelTIMMModel:
    def launch_demo(self, default_thresholds: Union[float, Dict[Any, float]] = None,
                    default_use_tag_thresholds: bool = True,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        """
        Launch a Gradio demo for the model.

        :param default_thresholds: Default threshold values to use in the demo
        :type default_thresholds: Union[float, Dict[Any, float]]
        :param default_use_tag_thresholds: Whether to use tag-level thresholds by default
        :type default_use_tag_thresholds: bool
        :param server_name: Server name for the Gradio app
        :type server_name: Optional[str]
        :param server_port: Server port for the Gradio app
        :type server_port: Optional[int]
        :param kwargs: Additional keyword arguments to pass to gr.launch()
        :type kwargs: Any

        :return: None
        :raises EnvironmentError: If Gradio is not installed
        """
        _check_gradio_env()
        with gr.Blocks() as demo:
            with gr.Row():
@@ -322,6 +465,19 @@ class MultiLabelTIMMModel:
def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[Any, str], ...]] = None,
                             hf_token: Optional[str] = None) \
        -> MultiLabelTIMMModel:
    """
    Open and cache a MultiLabelTIMMModel for a given repository ID.

    :param repo_id: The Hugging Face Hub repository ID
    :type repo_id: str
    :param category_names: Optional tuple of (category_id, name) pairs for category naming
    :type category_names: Optional[Tuple[Tuple[Any, str], ...]]
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]

    :return: A cached MultiLabelTIMMModel instance
    :rtype: MultiLabelTIMMModel
    """
    return MultiLabelTIMMModel(
        repo_id=repo_id,
        hf_token=hf_token,
@@ -333,6 +489,58 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di
                            preprocessor: Literal['test', 'val'] = 'test',
                            thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True,
                            fmt=FMT_UNSET, hf_token: Optional[str] = None):
    """
    Make predictions using a multi-label TIMM model.

    This function provides a convenient interface for making predictions with models
    hosted on Hugging Face Hub without directly instantiating a MultiLabelTIMMModel.

    :param image: The input image
    :type image: ImageTyping
    :param repo_id: The Hugging Face Hub repository ID containing the model
    :type repo_id: str
    :param category_names: Optional mapping of category IDs to display names
    :type category_names: Dict[Any, str]
    :param preprocessor: Which preprocessor to use ('test' or 'val')
    :type preprocessor: Literal['test', 'val']
    :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories
                      or a dictionary mapping category IDs or names to threshold values
    :type thresholds: Union[float, Dict[Any, float]]
    :param use_tag_thresholds: Whether to use tag-level thresholds if available
    :type use_tag_thresholds: bool
    :param fmt: Output format specification. Can be a tuple of category names to include,
               or FMT_UNSET to use all categories
    :type fmt: Any
    :param hf_token: Optional Hugging Face authentication token for private repositories
    :type hf_token: Optional[str]

    :return: Formatted prediction results according to the fmt parameter
    :rtype: Any

    .. note::
        The fmt argument can include the following keys:

        - Category names: dicts containing category-specific tags and their confidences
        - ``tag``: a dict containing all tags across categories and their confidences
        - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization
        - ``logits``: a 1-dim logits result of image.
        - ``prediction``: a 1-dim prediction result of image

        You can extract specific category predictions or all tags based on your needs.

        Example:

        >>> # Get all categories and tags
        >>> result = multilabel_timm_predict('image.jpg', 'username/model-name')
        >>>
        >>> # Get only specific categories
        >>> result = multilabel_timm_predict('image.jpg', 'username/model-name',
        ...                                  fmt=('general', 'character'))
        >>>
        >>> # Get just the raw prediction values
        >>> prediction = multilabel_timm_predict('image.jpg', 'username/model-name',
        ...                                     fmt='prediction')
    """
    model = _open_models_for_repo_id(
        repo_id=repo_id,
        category_names=tuple((key, value) for key, value in sorted((category_names or {}).items())),