Commit db8a1bae authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add pixai docs

parent 313a6536
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ imgutils.tagging
    mldanbooru
    wd14
    camie
    pixai
    deepdanbooru
    deepgelbooru
    format
+15 −0
Original line number Diff line number Diff line
imgutils.tagging.pixai
====================================

.. currentmodule:: imgutils.tagging.pixai

.. automodule:: imgutils.tagging.pixai


get_pixai_tags
----------------------

.. autofunction:: get_pixai_tags


+148 −39
Original line number Diff line number Diff line
"""
Overview:
    This module provides utilities for image tagging using PixAI taggers, which are specialized models
    for analyzing anime-style images and extracting relevant tags. The module supports loading ONNX
    models from Hugging Face Hub and processing images to generate categorized tags with confidence scores.

    The models are originally developed by the PixAI team and available at 
    `pixai-labs <https://huggingface.co/pixai-labs>`_ on Hugging Face. This module uses ONNX-converted
    versions of these models for efficient inference, available at 
    `deepghs <https://huggingface.co/deepghs>`_ repositories.

    Example::

        >>> from imgutils.tagging.pixai import get_pixai_tags
        >>> # Get tags with default thresholds
        >>> result = get_pixai_tags('path/to/anime_image.jpg', model_name='v0.9')
        >>> general_tags, character_tags = result
        >>> print("General tags:", general_tags)
        >>> print("Character tags:", character_tags)

        >>> # Get all tags in a single dictionary
        >>> all_tags = get_pixai_tags('path/to/image.jpg', fmt='tag')
        >>> print("All tags:", all_tags)
"""

import json
from typing import Union, Dict, Any
from typing import Union, Dict, Any, Tuple

import pandas as pd
from hbutils.design import SingletonMark
@@ -14,6 +38,21 @@ FMT_UNSET = SingletonMark('FMT_UNSET')


def _get_repo_id(model_name: str) -> str:
    """
    Get the repository ID for the specified model name.

    :param model_name: Name of the model (e.g., 'v0.9') or full repository path
    :type model_name: str

    :return: Full repository ID for Hugging Face Hub
    :rtype: str

    Example::
        >>> _get_repo_id('v0.9')
        'deepghs/pixai-tagger-v0.9-onnx'
        >>> _get_repo_id('custom/model-repo')
        'custom/model-repo'
    """
    if '/' in model_name:
        return model_name
    else:
@@ -23,10 +62,16 @@ def _get_repo_id(model_name: str) -> str:
@ts_lru_cache()
def _open_onnx_model(model_name: str):
    """
    Load the ONNX model from Hugging Face Hub.
    Load the ONNX model from Hugging Face Hub with caching.

    :return: The loaded ONNX model
    :rtype: object
    This function downloads and loads the ONNX model file for the specified PixAI tagger.
    Results are cached to avoid repeated downloads and model loading.

    :param model_name: Name of the model to load
    :type model_name: str

    :return: The loaded ONNX model session
    :rtype: onnxruntime.InferenceSession
    """
    return open_onnx_model(hf_hub_download(
        repo_id=_get_repo_id(model_name),
@@ -37,6 +82,19 @@ def _open_onnx_model(model_name: str):

@ts_lru_cache()
def _open_tags(model_name: str) -> pd.DataFrame:
    """
    Load the tag metadata from Hugging Face Hub with caching.

    This function downloads and loads the CSV file containing tag names and categories
    for the specified model. The DataFrame contains columns for tag names, categories,
    and other metadata.

    :param model_name: Name of the model
    :type model_name: str

    :return: DataFrame containing tag information with columns like 'name', 'category'
    :rtype: pd.DataFrame
    """
    return pd.read_csv(hf_hub_download(
        repo_id=_get_repo_id(model_name),
        repo_type='model',
@@ -46,6 +104,17 @@ def _open_tags(model_name: str) -> pd.DataFrame:

@ts_lru_cache()
def _open_preprocess(model_name: str):
    """
    Load the preprocessing pipeline configuration from Hugging Face Hub with caching.

    This function downloads the preprocessing configuration and creates a PIL transforms
    pipeline for image preprocessing before model inference.

    :param model_name: Name of the model
    :type model_name: str

    :return: Preprocessing transform pipeline
    """
    with open(hf_hub_download(
            repo_id=_get_repo_id(model_name),
            repo_type='model',
@@ -56,12 +125,23 @@ def _open_preprocess(model_name: str):


@ts_lru_cache()
def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float], Dict[int, str]]:
def _open_default_category_thresholds(model_name: str) -> Tuple[Dict[int, float], Dict[int, str]]:
    """
    Load default category thresholds from the Hugging Face Hub.
    Load default category thresholds and names from the Hugging Face Hub with caching.

    :return: Dictionary mapping category IDs to threshold values
    :rtype: dict
    This function attempts to load predefined threshold values for each category from
    a CSV file. If the file doesn't exist, empty dictionaries are returned.

    :param model_name: Name of the model
    :type model_name: str

    :return: Tuple containing (category_thresholds, category_names) dictionaries
    :rtype: tuple[Dict[int, float], Dict[int, str]]

    Example::
        >>> thresholds, names = _open_default_category_thresholds('v0.9')
        >>> print(thresholds)  # {0: 0.35, 1: 0.4, ...}
        >>> print(names)      # {0: 'general', 1: 'character', ...}
    """
    _default_category_thresholds: Dict[int, float] = {}
    _category_names: Dict[int, str] = {}
@@ -84,16 +164,23 @@ def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float]

def _raw_predict(image: ImageTyping, model_name: str):
    """
    Make a raw prediction with the model.
    Make a raw prediction with the PixAI tagger model.

    :param image: The input image
    This function preprocesses the input image and runs inference using the specified
    ONNX model. It returns the raw model outputs without any post-processing or
    threshold application.

    :param image: The input image to analyze
    :type image: ImageTyping
    :param preprocessor: Which preprocessor to use ('test' or 'val')
    :type preprocessor: Literal['test', 'val']
    :param model_name: Name of the model to use for prediction
    :type model_name: str

    :return: Dictionary of model outputs
    :return: Dictionary containing raw model outputs with keys like 'prediction', 'embedding', 'logits'
    :rtype: dict
    :raises ValueError: If an unknown preprocessor is specified

    Example::
        >>> raw_output = _raw_predict('anime_image.jpg', 'v0.9')
        >>> print(raw_output.keys())  # dict_keys(['prediction', 'embedding', 'logits'])
    """
    image = load_image(image, force_background='white', mode='RGB')
    model = _open_onnx_model(model_name=model_name)
@@ -107,39 +194,61 @@ def _raw_predict(image: ImageTyping, model_name: str):
def get_pixai_tags(image: ImageTyping, model_name: str = 'v0.9',
                   thresholds: Union[float, Dict[Any, float]] = None, fmt=FMT_UNSET):
    """
    Make a prediction and format the results.
    Extract tags from an image using PixAI tagger models.

    This method processes an image through the model and applies thresholds to determine
    which tags to include in the results. The output format can be customized using the fmt parameter.
    This function processes an image through a PixAI tagger model and applies confidence
    thresholds to determine which tags to include in the results. The output format can
    be customized to return specific categories or all tags together.

    :param image: The input image
    :param image: The input image to analyze (file path, PIL Image, numpy array, etc.)
    :type image: ImageTyping
    :param preprocessor: Which preprocessor to use ('test' or 'val')
    :type preprocessor: Literal['test', 'val']
    :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories
                      or a dictionary mapping category IDs or names to threshold values
    :type thresholds: Union[float, Dict[Any, float]]
    :param use_tag_thresholds: Whether to use tag-level thresholds if available
    :type use_tag_thresholds: bool
    :param fmt: Output format specification. Can be a tuple of category names to include,
               or FMT_UNSET to use all categories
    :param model_name: Name or path of the PixAI tagger model to use
    :type model_name: str
    :param thresholds: Confidence threshold values. Can be a single float applied to all 
                      categories, or a dictionary mapping category IDs/names to specific thresholds
    :type thresholds: Union[float, Dict[Any, float]], optional
    :param fmt: Output format specification. If FMT_UNSET, returns all available categories.
               Can also be a single key (e.g. ``'tag'``) or a tuple of keys to include in the output
    :type fmt: Any

    :return: Formatted prediction results according to the fmt parameter
    :return: Formatted prediction results. By default, returns a tuple of
            (general_tags, character_tags, ...) based on the available categories; a custom
            format is returned when the ``fmt`` parameter is given
    :rtype: Any

    .. note::
        The fmt argument can include the following keys:

        - Category names: dicts containing category-specific tags and their confidences
        - ``tag``: a dict containing all tags across categories and their confidences
        - ``embedding``: a 1-dim embedding of image, recommended for index building after L2 normalization
        - ``logits``: a 1-dim logits result of image
        - ``prediction``: a 1-dim prediction result of image

        You can extract specific category predictions or all tags based on your needs.

    For more details see documentation of :func:`multilabel_timm_predict`.
        The fmt parameter can include the following keys:

        - Category names (e.g., 'general', 'character'): dictionaries containing category-specific 
          tags and their confidence scores
        - ``tag``: a dictionary containing all tags across categories and their confidences
        - ``embedding``: a 1-dimensional embedding vector of the image, recommended for similarity 
          search after L2 normalization
        - ``logits``: raw 1-dimensional logits output from the model
        - ``prediction``: 1-dimensional prediction probabilities from the model

        If ``thresholds`` is not specified, the model's default category thresholds are used.
        These vary by model and category but typically range from 0.35 to 0.5.

    Example::

        >>> from imgutils.tagging.pixai import get_pixai_tags
        >>> 
        >>> # Get tags with default format (all categories)
        >>> general_tags, character_tags = get_pixai_tags('anime_image.jpg', model_name='v0.9')
        >>> print("General tags:", general_tags)
        >>> print("Character tags:", character_tags)
        >>> 
        >>> # Get all tags in a single dictionary
        >>> all_tags = get_pixai_tags('image.jpg', fmt='tag')
        >>> print("All tags:", all_tags)
        >>> 
        >>> # Use custom thresholds
        >>> result = get_pixai_tags('image.jpg', thresholds={'general': 0.3, 'character': 0.5})
        >>> 
        >>> # Get embedding for similarity search
        >>> embedding = get_pixai_tags('image.jpg', fmt='embedding')
        >>> # Normalize for cosine similarity
        >>> import numpy as np
        >>> normalized_embedding = embedding / np.linalg.norm(embedding)
    """
    df_tags = _open_tags(model_name=model_name)
    values = _raw_predict(image, model_name=model_name)