Loading docs/source/api_doc/generic/index.rst +1 −0 Original line number Diff line number Diff line Loading @@ -12,5 +12,6 @@ imgutils.generic classify enhance clip multilabel_timm siglip yolo docs/source/api_doc/generic/multilabel_timm.rst 0 → 100644 +23 −0 Original line number Diff line number Diff line imgutils.generic.multilabel_timm ======================================= .. currentmodule:: imgutils.generic.multilabel_timm .. automodule:: imgutils.generic.multilabel_timm MultiLabelTIMMModel ---------------------------------------------- .. autoclass:: MultiLabelTIMMModel :members: __init__,predict,make_ui,launch_demo multilabel_timm_predict -------------------------------------------------- .. autofunction:: multilabel_timm_predict imgutils/generic/multilabel_timm.py +208 −0 Original line number Diff line number Diff line """ Multi-Label TIMM Model Module This module provides functionality for working with multi-label image classification models trained with TIMM (PyTorch Image Models) and exported to ONNX format. It includes: 1. The MultiLabelTIMMModel class for loading and making predictions with models hosted on Hugging Face Hub 2. Functions for batch prediction and demo interface creation 3. Support for custom thresholds at both category and tag levels 4. Flexible output formatting options for different use cases The models are expected to be stored on Hugging Face Hub with specific files: - model.onnx: The ONNX model file - selected_tags.csv: CSV file containing tag information and categories - preprocess.json: JSON configuration for image preprocessing - thresholds.csv: Optional CSV file with recommended thresholds This module is designed to work with multi-label classification tasks where images can belong to multiple categories and have multiple tags within each category. """ import io import json import os Loading Loading @@ -45,7 +67,31 @@ FMT_UNSET = object() class MultiLabelTIMMModel: """ A class for working with multi-label image classification models trained with TIMM. This class handles loading models from Hugging Face Hub, preprocessing images, and making predictions with customizable thresholds. :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ def __init__(self, repo_id: str, hf_token: Optional[str] = None, category_names: Dict[Any, str] = None): """ Initialize a MultiLabelTIMMModel. :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ self.repo_id = repo_id self._model = None self._df_tags = None Loading @@ -68,6 +114,12 @@ class MultiLabelTIMMModel: return self._hf_token or os.environ.get('HF_TOKEN') def _open_model(self): """ Load the ONNX model from Hugging Face Hub. :return: The loaded ONNX model :rtype: object """ with self._lock: if self._model is None: self._model = open_onnx_model(hf_hub_download( Loading @@ -80,6 +132,12 @@ class MultiLabelTIMMModel: return self._model def _open_tags(self): """ Load tag information from the Hugging Face Hub. :return: DataFrame containing tag information :rtype: pandas.DataFrame """ with self._lock: if self._df_tags is None: self._df_tags = pd.read_csv(hf_hub_download( Loading @@ -97,6 +155,12 @@ class MultiLabelTIMMModel: return self._df_tags def _open_preprocess(self): """ Load preprocessing configuration from the Hugging Face Hub. :return: A tuple of validation and test preprocessing transforms :rtype: tuple """ with self._lock: if self._preprocess is None: with open(hf_hub_download( Loading @@ -112,6 +176,12 @@ class MultiLabelTIMMModel: return self._preprocess def _open_default_category_thresholds(self): """ Load default category thresholds from the Hugging Face Hub. :return: Dictionary mapping category IDs to threshold values :rtype: dict """ with self._lock: if self._default_category_thresholds is None: try: Loading @@ -131,6 +201,18 @@ class MultiLabelTIMMModel: return self._default_category_thresholds def _raw_predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test'): """ Make a raw prediction with the model. :param image: The input image :type image: ImageTyping :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :return: Dictionary of model outputs :rtype: dict :raises ValueError: If an unknown preprocessor is specified """ image = load_image(image, force_background='white', mode='RGB') model = self._open_model() Loading @@ -150,6 +232,39 @@ class MultiLabelTIMMModel: def predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test', thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET): """ Make a prediction and format the results. This method processes an image through the model and applies thresholds to determine which tags to include in the results. The output format can be customized using the fmt parameter. :param image: The input image :type image: ImageTyping :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories or a dictionary mapping category IDs or names to threshold values :type thresholds: Union[float, Dict[Any, float]] :param use_tag_thresholds: Whether to use tag-level thresholds if available :type use_tag_thresholds: bool :param fmt: Output format specification. Can be a tuple of category names to include, or FMT_UNSET to use all categories :type fmt: Any :return: Formatted prediction results according to the fmt parameter :rtype: Any .. note:: The fmt argument can include the following keys: - Category names: dicts containing category-specific tags and their confidences - ``tag``: a dict containing all tags across categories and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``logits``: a 1-dim logits result of image. - ``prediction``: a 1-dim prediction result of image You can extract specific category predictions or all tags based on your needs. """ df_tags = self._open_tags() values = self._raw_predict(image, preprocessor=preprocessor) prediction = values['prediction'] Loading Loading @@ -200,6 +315,17 @@ class MultiLabelTIMMModel: def make_ui(self, default_thresholds: Union[float, Dict[Any, float]] = None, default_use_tag_thresholds: bool = True): """ Create a Gradio UI for the model. :param default_thresholds: Default threshold values to use in the UI :type default_thresholds: Union[float, Dict[Any, float]] :param default_use_tag_thresholds: Whether to use tag-level thresholds by default :type default_use_tag_thresholds: bool :return: None :raises EnvironmentError: If Gradio is not installed """ _check_gradio_env() df_tags = self._open_tags() default_category_thresholds = self._open_default_category_thresholds() Loading Loading @@ -296,6 +422,23 @@ class MultiLabelTIMMModel: def launch_demo(self, default_thresholds: Union[float, Dict[Any, float]] = None, default_use_tag_thresholds: bool = True, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): """ Launch a Gradio demo for the model. :param default_thresholds: Default threshold values to use in the demo :type default_thresholds: Union[float, Dict[Any, float]] :param default_use_tag_thresholds: Whether to use tag-level thresholds by default :type default_use_tag_thresholds: bool :param server_name: Server name for the Gradio app :type server_name: Optional[str] :param server_port: Server port for the Gradio app :type server_port: Optional[int] :param kwargs: Additional keyword arguments to pass to gr.launch() :type kwargs: Any :return: None :raises EnvironmentError: If Gradio is not installed """ _check_gradio_env() with gr.Blocks() as demo: with gr.Row(): Loading @@ -322,6 +465,19 @@ class MultiLabelTIMMModel: def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[Any, str], ...]] = None, hf_token: Optional[str] = None) \ -> MultiLabelTIMMModel: """ Open and cache a MultiLabelTIMMModel for a given repository ID. :param repo_id: The Hugging Face Hub repository ID :type repo_id: str :param category_names: Optional tuple of (category_id, name) pairs for category naming :type category_names: Optional[Tuple[Tuple[Any, str], ...]] :param hf_token: Optional Hugging Face authentication token :type hf_token: Optional[str] :return: A cached MultiLabelTIMMModel instance :rtype: MultiLabelTIMMModel """ return MultiLabelTIMMModel( repo_id=repo_id, hf_token=hf_token, Loading @@ -333,6 +489,58 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di preprocessor: Literal['test', 'val'] = 'test', thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET, hf_token: Optional[str] = None): """ Make predictions using a multi-label TIMM model. This function provides a convenient interface for making predictions with models hosted on Hugging Face Hub without directly instantiating a MultiLabelTIMMModel. :param image: The input image :type image: ImageTyping :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories or a dictionary mapping category IDs or names to threshold values :type thresholds: Union[float, Dict[Any, float]] :param use_tag_thresholds: Whether to use tag-level thresholds if available :type use_tag_thresholds: bool :param fmt: Output format specification. Can be a tuple of category names to include, or FMT_UNSET to use all categories :type fmt: Any :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :return: Formatted prediction results according to the fmt parameter :rtype: Any .. note:: The fmt argument can include the following keys: - Category names: dicts containing category-specific tags and their confidences - ``tag``: a dict containing all tags across categories and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``logits``: a 1-dim logits result of image. - ``prediction``: a 1-dim prediction result of image You can extract specific category predictions or all tags based on your needs. Example: >>> # Get all categories and tags >>> result = multilabel_timm_predict('image.jpg', 'username/model-name') >>> >>> # Get only specific categories >>> result = multilabel_timm_predict('image.jpg', 'username/model-name', ... fmt=('general', 'character')) >>> >>> # Get just the raw prediction values >>> prediction = multilabel_timm_predict('image.jpg', 'username/model-name', ... fmt='prediction') """ model = _open_models_for_repo_id( repo_id=repo_id, category_names=tuple((key, value) for key, value in sorted((category_names or {}).items())), Loading Loading
docs/source/api_doc/generic/index.rst +1 −0 Original line number Diff line number Diff line Loading @@ -12,5 +12,6 @@ imgutils.generic classify enhance clip multilabel_timm siglip yolo
docs/source/api_doc/generic/multilabel_timm.rst 0 → 100644 +23 −0 Original line number Diff line number Diff line imgutils.generic.multilabel_timm ======================================= .. currentmodule:: imgutils.generic.multilabel_timm .. automodule:: imgutils.generic.multilabel_timm MultiLabelTIMMModel ---------------------------------------------- .. autoclass:: MultiLabelTIMMModel :members: __init__,predict,make_ui,launch_demo multilabel_timm_predict -------------------------------------------------- .. autofunction:: multilabel_timm_predict
imgutils/generic/multilabel_timm.py +208 −0 Original line number Diff line number Diff line """ Multi-Label TIMM Model Module This module provides functionality for working with multi-label image classification models trained with TIMM (PyTorch Image Models) and exported to ONNX format. It includes: 1. The MultiLabelTIMMModel class for loading and making predictions with models hosted on Hugging Face Hub 2. Functions for batch prediction and demo interface creation 3. Support for custom thresholds at both category and tag levels 4. Flexible output formatting options for different use cases The models are expected to be stored on Hugging Face Hub with specific files: - model.onnx: The ONNX model file - selected_tags.csv: CSV file containing tag information and categories - preprocess.json: JSON configuration for image preprocessing - thresholds.csv: Optional CSV file with recommended thresholds This module is designed to work with multi-label classification tasks where images can belong to multiple categories and have multiple tags within each category. """ import io import json import os Loading Loading @@ -45,7 +67,31 @@ FMT_UNSET = object() class MultiLabelTIMMModel: """ A class for working with multi-label image classification models trained with TIMM. This class handles loading models from Hugging Face Hub, preprocessing images, and making predictions with customizable thresholds. :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ def __init__(self, repo_id: str, hf_token: Optional[str] = None, category_names: Dict[Any, str] = None): """ Initialize a MultiLabelTIMMModel. :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ self.repo_id = repo_id self._model = None self._df_tags = None Loading @@ -68,6 +114,12 @@ class MultiLabelTIMMModel: return self._hf_token or os.environ.get('HF_TOKEN') def _open_model(self): """ Load the ONNX model from Hugging Face Hub. :return: The loaded ONNX model :rtype: object """ with self._lock: if self._model is None: self._model = open_onnx_model(hf_hub_download( Loading @@ -80,6 +132,12 @@ class MultiLabelTIMMModel: return self._model def _open_tags(self): """ Load tag information from the Hugging Face Hub. :return: DataFrame containing tag information :rtype: pandas.DataFrame """ with self._lock: if self._df_tags is None: self._df_tags = pd.read_csv(hf_hub_download( Loading @@ -97,6 +155,12 @@ class MultiLabelTIMMModel: return self._df_tags def _open_preprocess(self): """ Load preprocessing configuration from the Hugging Face Hub. :return: A tuple of validation and test preprocessing transforms :rtype: tuple """ with self._lock: if self._preprocess is None: with open(hf_hub_download( Loading @@ -112,6 +176,12 @@ class MultiLabelTIMMModel: return self._preprocess def _open_default_category_thresholds(self): """ Load default category thresholds from the Hugging Face Hub. :return: Dictionary mapping category IDs to threshold values :rtype: dict """ with self._lock: if self._default_category_thresholds is None: try: Loading @@ -131,6 +201,18 @@ class MultiLabelTIMMModel: return self._default_category_thresholds def _raw_predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test'): """ Make a raw prediction with the model. :param image: The input image :type image: ImageTyping :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :return: Dictionary of model outputs :rtype: dict :raises ValueError: If an unknown preprocessor is specified """ image = load_image(image, force_background='white', mode='RGB') model = self._open_model() Loading @@ -150,6 +232,39 @@ class MultiLabelTIMMModel: def predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test', thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET): """ Make a prediction and format the results. This method processes an image through the model and applies thresholds to determine which tags to include in the results. The output format can be customized using the fmt parameter. :param image: The input image :type image: ImageTyping :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories or a dictionary mapping category IDs or names to threshold values :type thresholds: Union[float, Dict[Any, float]] :param use_tag_thresholds: Whether to use tag-level thresholds if available :type use_tag_thresholds: bool :param fmt: Output format specification. Can be a tuple of category names to include, or FMT_UNSET to use all categories :type fmt: Any :return: Formatted prediction results according to the fmt parameter :rtype: Any .. note:: The fmt argument can include the following keys: - Category names: dicts containing category-specific tags and their confidences - ``tag``: a dict containing all tags across categories and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``logits``: a 1-dim logits result of image. - ``prediction``: a 1-dim prediction result of image You can extract specific category predictions or all tags based on your needs. """ df_tags = self._open_tags() values = self._raw_predict(image, preprocessor=preprocessor) prediction = values['prediction'] Loading Loading @@ -200,6 +315,17 @@ class MultiLabelTIMMModel: def make_ui(self, default_thresholds: Union[float, Dict[Any, float]] = None, default_use_tag_thresholds: bool = True): """ Create a Gradio UI for the model. :param default_thresholds: Default threshold values to use in the UI :type default_thresholds: Union[float, Dict[Any, float]] :param default_use_tag_thresholds: Whether to use tag-level thresholds by default :type default_use_tag_thresholds: bool :return: None :raises EnvironmentError: If Gradio is not installed """ _check_gradio_env() df_tags = self._open_tags() default_category_thresholds = self._open_default_category_thresholds() Loading Loading @@ -296,6 +422,23 @@ class MultiLabelTIMMModel: def launch_demo(self, default_thresholds: Union[float, Dict[Any, float]] = None, default_use_tag_thresholds: bool = True, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): """ Launch a Gradio demo for the model. :param default_thresholds: Default threshold values to use in the demo :type default_thresholds: Union[float, Dict[Any, float]] :param default_use_tag_thresholds: Whether to use tag-level thresholds by default :type default_use_tag_thresholds: bool :param server_name: Server name for the Gradio app :type server_name: Optional[str] :param server_port: Server port for the Gradio app :type server_port: Optional[int] :param kwargs: Additional keyword arguments to pass to gr.launch() :type kwargs: Any :return: None :raises EnvironmentError: If Gradio is not installed """ _check_gradio_env() with gr.Blocks() as demo: with gr.Row(): Loading @@ -322,6 +465,19 @@ class MultiLabelTIMMModel: def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[Any, str], ...]] = None, hf_token: Optional[str] = None) \ -> MultiLabelTIMMModel: """ Open and cache a MultiLabelTIMMModel for a given repository ID. :param repo_id: The Hugging Face Hub repository ID :type repo_id: str :param category_names: Optional tuple of (category_id, name) pairs for category naming :type category_names: Optional[Tuple[Tuple[Any, str], ...]] :param hf_token: Optional Hugging Face authentication token :type hf_token: Optional[str] :return: A cached MultiLabelTIMMModel instance :rtype: MultiLabelTIMMModel """ return MultiLabelTIMMModel( repo_id=repo_id, hf_token=hf_token, Loading @@ -333,6 +489,58 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di preprocessor: Literal['test', 'val'] = 'test', thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET, hf_token: Optional[str] = None): """ Make predictions using a multi-label TIMM model. This function provides a convenient interface for making predictions with models hosted on Hugging Face Hub without directly instantiating a MultiLabelTIMMModel. :param image: The input image :type image: ImageTyping :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories or a dictionary mapping category IDs or names to threshold values :type thresholds: Union[float, Dict[Any, float]] :param use_tag_thresholds: Whether to use tag-level thresholds if available :type use_tag_thresholds: bool :param fmt: Output format specification. Can be a tuple of category names to include, or FMT_UNSET to use all categories :type fmt: Any :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :return: Formatted prediction results according to the fmt parameter :rtype: Any .. note:: The fmt argument can include the following keys: - Category names: dicts containing category-specific tags and their confidences - ``tag``: a dict containing all tags across categories and their confidences - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization - ``logits``: a 1-dim logits result of image. - ``prediction``: a 1-dim prediction result of image You can extract specific category predictions or all tags based on your needs. Example: >>> # Get all categories and tags >>> result = multilabel_timm_predict('image.jpg', 'username/model-name') >>> >>> # Get only specific categories >>> result = multilabel_timm_predict('image.jpg', 'username/model-name', ... fmt=('general', 'character')) >>> >>> # Get just the raw prediction values >>> prediction = multilabel_timm_predict('image.jpg', 'username/model-name', ... fmt='prediction') """ model = _open_models_for_repo_id( repo_id=repo_id, category_names=tuple((key, value) for key, value in sorted((category_names or {}).items())), Loading