Loading imgutils/tagging/camie.py +66 −14 Original line number Diff line number Diff line """ This module provides functionality for image tagging using the Camie model from Hugging Face Hub. It includes tools for loading models, processing images, and extracting tags across different categories like rating, general tags, characters, artists, and more. """ import json from typing import List, Tuple, Dict, Optional, Union, Literal, Any Loading Loading @@ -27,11 +32,13 @@ _CATEGORY_MAPS = { @ts_lru_cache() def _get_camie_model(model_name, is_full: bool): """ Load an ONNX model from the Hugging Face Hub. Load and cache a Camie ONNX model from the Hugging Face Hub. :param model_name: The name of the model to load. :param model_name: Name of the model to load :type model_name: str :return: The loaded ONNX model. :param is_full: Whether to load the full model or initial-only version :type is_full: bool :return: Loaded ONNX model :rtype: ONNXModel """ return open_onnx_model(hf_hub_download( Loading @@ -44,13 +51,13 @@ def _get_camie_model(model_name, is_full: bool): @ts_lru_cache() def _get_camie_labels(model_name, no_underline: bool = False) -> Tuple[List[str], Dict[str, List[int]]]: """ Get labels for the Camie model. Retrieve and process labels for the Camie model. :param model_name: The name of the model. :param model_name: Name of the model :type model_name: str :param no_underline: If True, replaces underscores in tag names with spaces. :param no_underline: If True, replace underscores with spaces in tag names :type no_underline: bool :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories. :return: Tuple of (list of tag names, dictionary mapping category names to their indices) :rtype: Tuple[List[str], Dict[str, List[int]]] """ path = hf_hub_download( Loading @@ -73,6 +80,14 @@ def _get_camie_labels(model_name, no_underline: bool = False) -> Tuple[List[str] @ts_lru_cache() def _get_camie_preprocessor(model_name: str): """ Get the image preprocessor for the specified Camie model. :param model_name: Name of the model :type model_name: str :return: Pillow transform pipeline for image preprocessing :rtype: callable """ with open(hf_hub_download( repo_id=_REPO_ID, repo_type='model', Loading @@ -87,6 +102,16 @@ CamieModeTyping = Literal['balanced', 'high_precision', 'high_recall', 'micro_op @ts_lru_cache() def _get_camie_threshold(model_name: str, mode: CamieModeTyping = 'balanced'): """ Get threshold values for different categories based on the specified mode. :param model_name: Name of the model :type model_name: str :param mode: Prediction mode affecting threshold values :type mode: CamieModeTyping :return: Dictionary of thresholds for each category :rtype: Dict[str, float] """ with open(hf_hub_download( repo_id=_REPO_ID, repo_type='model', Loading @@ -108,19 +133,26 @@ def _postprocess_embedding_values( drop_overlap: bool = False, ): """ Post-process the embedding and prediction results. Post-process model predictions and embeddings into structured tag results. :param pred: The prediction array. :param pred: Raw prediction array from the model :type pred: numpy.ndarray :param embedding: The embedding array. :param logits: Logits array from the model :type logits: numpy.ndarray :param embedding: Embedding array from the model :type embedding: numpy.ndarray :param model_name: The name of the model used. :param model_name: Name of the model used :type model_name: str :param no_underline: Whether to remove underscores from tag names. :param mode: Prediction mode for threshold selection :type mode: CamieModeTyping :param thresholds: Custom thresholds for tag selection :type thresholds: Optional[Union[float, Dict[str, float]]] :param no_underline: Whether to remove underscores from tag names :type no_underline: bool :param drop_overlap: Whether to drop overlapping tags. :param drop_overlap: Whether to remove overlapping tags :type drop_overlap: bool :return: The post-processed results. :return: Dictionary containing processed predictions and embeddings :rtype: Dict[str, Any] """ assert len(pred.shape) == len(embedding.shape) == 1, \ f'Both pred and embeddings shapes should be 1-dim, ' \ Loading Loading @@ -173,6 +205,26 @@ def get_camie_tags( drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), ): """ Extract tags from an image using the Camie model. :param image: Input image (can be path, URL, or image data) :type image: ImageTyping :param model_name: Name of the Camie model to use :type model_name: str :param mode: Prediction mode affecting threshold values :type mode: CamieModeTyping :param thresholds: Custom thresholds for tag selection :type thresholds: Optional[Union[float, Dict[str, float]]] :param no_underline: Whether to remove underscores from tag names :type no_underline: bool :param drop_overlap: Whether to remove overlapping tags :type drop_overlap: bool :param fmt: Format specification for output :type fmt: Any :return: Dictionary of extracted tags and embeddings :rtype: Dict[str, Any] """ names = vnames(fmt) need_full = False for name in names: Loading Loading
imgutils/tagging/camie.py +66 −14 Original line number Diff line number Diff line """ This module provides functionality for image tagging using the Camie model from Hugging Face Hub. It includes tools for loading models, processing images, and extracting tags across different categories like rating, general tags, characters, artists, and more. """ import json from typing import List, Tuple, Dict, Optional, Union, Literal, Any Loading Loading @@ -27,11 +32,13 @@ _CATEGORY_MAPS = { @ts_lru_cache() def _get_camie_model(model_name, is_full: bool): """ Load an ONNX model from the Hugging Face Hub. Load and cache a Camie ONNX model from the Hugging Face Hub. :param model_name: The name of the model to load. :param model_name: Name of the model to load :type model_name: str :return: The loaded ONNX model. :param is_full: Whether to load the full model or initial-only version :type is_full: bool :return: Loaded ONNX model :rtype: ONNXModel """ return open_onnx_model(hf_hub_download( Loading @@ -44,13 +51,13 @@ def _get_camie_model(model_name, is_full: bool): @ts_lru_cache() def _get_camie_labels(model_name, no_underline: bool = False) -> Tuple[List[str], Dict[str, List[int]]]: """ Get labels for the Camie model. Retrieve and process labels for the Camie model. :param model_name: The name of the model. :param model_name: Name of the model :type model_name: str :param no_underline: If True, replaces underscores in tag names with spaces. :param no_underline: If True, replace underscores with spaces in tag names :type no_underline: bool :return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories. :return: Tuple of (list of tag names, dictionary mapping category names to their indices) :rtype: Tuple[List[str], Dict[str, List[int]]] """ path = hf_hub_download( Loading @@ -73,6 +80,14 @@ def _get_camie_labels(model_name, no_underline: bool = False) -> Tuple[List[str] @ts_lru_cache() def _get_camie_preprocessor(model_name: str): """ Get the image preprocessor for the specified Camie model. :param model_name: Name of the model :type model_name: str :return: Pillow transform pipeline for image preprocessing :rtype: callable """ with open(hf_hub_download( repo_id=_REPO_ID, repo_type='model', Loading @@ -87,6 +102,16 @@ CamieModeTyping = Literal['balanced', 'high_precision', 'high_recall', 'micro_op @ts_lru_cache() def _get_camie_threshold(model_name: str, mode: CamieModeTyping = 'balanced'): """ Get threshold values for different categories based on the specified mode. :param model_name: Name of the model :type model_name: str :param mode: Prediction mode affecting threshold values :type mode: CamieModeTyping :return: Dictionary of thresholds for each category :rtype: Dict[str, float] """ with open(hf_hub_download( repo_id=_REPO_ID, repo_type='model', Loading @@ -108,19 +133,26 @@ def _postprocess_embedding_values( drop_overlap: bool = False, ): """ Post-process the embedding and prediction results. Post-process model predictions and embeddings into structured tag results. :param pred: The prediction array. :param pred: Raw prediction array from the model :type pred: numpy.ndarray :param embedding: The embedding array. :param logits: Logits array from the model :type logits: numpy.ndarray :param embedding: Embedding array from the model :type embedding: numpy.ndarray :param model_name: The name of the model used. :param model_name: Name of the model used :type model_name: str :param no_underline: Whether to remove underscores from tag names. :param mode: Prediction mode for threshold selection :type mode: CamieModeTyping :param thresholds: Custom thresholds for tag selection :type thresholds: Optional[Union[float, Dict[str, float]]] :param no_underline: Whether to remove underscores from tag names :type no_underline: bool :param drop_overlap: Whether to drop overlapping tags. :param drop_overlap: Whether to remove overlapping tags :type drop_overlap: bool :return: The post-processed results. :return: Dictionary containing processed predictions and embeddings :rtype: Dict[str, Any] """ assert len(pred.shape) == len(embedding.shape) == 1, \ f'Both pred and embeddings shapes should be 1-dim, ' \ Loading Loading @@ -173,6 +205,26 @@ def get_camie_tags( drop_overlap: bool = False, fmt: Any = ('rating', 'general', 'character'), ): """ Extract tags from an image using the Camie model. :param image: Input image (can be path, URL, or image data) :type image: ImageTyping :param model_name: Name of the Camie model to use :type model_name: str :param mode: Prediction mode affecting threshold values :type mode: CamieModeTyping :param thresholds: Custom thresholds for tag selection :type thresholds: Optional[Union[float, Dict[str, float]]] :param no_underline: Whether to remove underscores from tag names :type no_underline: bool :param drop_overlap: Whether to remove overlapping tags :type drop_overlap: bool :param fmt: Format specification for output :type fmt: Any :return: Dictionary of extracted tags and embeddings :rtype: Dict[str, Any] """ names = vnames(fmt) need_full = False for name in names: Loading