Commit 6cbeda91 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add docstring for imgutils.generic.attachment

parent 4200c013
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
imgutils.generic.attachment
=======================================

.. currentmodule:: imgutils.generic.attachment

.. automodule:: imgutils.generic.attachment



Attachment
-----------------------------------------

.. autoclass:: Attachment
    :members: __init__, encoder_model, predict


open_attachment
-----------------------------------------

.. autofunction:: open_attachment


+1 −0
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ imgutils.generic
.. toctree::
    :maxdepth: 3

    attachment
    classify
    enhance
    clip
+117 −3
Original line number Diff line number Diff line
"""
This module provides functionality for handling attachments in machine learning models,
particularly those hosted on Hugging Face's model hub. It includes tools for loading,
managing and making predictions with ONNX models for classification, tagging and regression tasks.

The module provides a caching mechanism for model loading and thread-safe operations
for concurrent access to models and their metadata.

An example of attachment models is `deepghs/eattach_monochrome_experiments <https://huggingface.co/deepghs/eattach_monochrome_experiments>`_.

.. note::
    If you want to train a custom attachment model for taggers,
    take a look at our framework `deepghs/emb_attachments <https://github.com/deepghs/emb_attachments>`_.
"""

import json
import os
from threading import Lock
from typing import Optional, Any
from typing import Optional, Any, Tuple

import numpy as np
from huggingface_hub import hf_hub_download

from imgutils.utils import open_onnx_model, vreplace, ts_lru_cache
from ..utils import open_onnx_model, vreplace, ts_lru_cache


class Attachment:
    """
    A class to manage machine learning model attachments from Hugging Face.

    This class handles model loading, caching, and prediction for various types of problems
    including classification, tagging, and regression.

    :param repo_id: The Hugging Face repository ID
    :type repo_id: str
    :param model_name: Name of the model
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]
    """

    def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
        """
        Initialize the Attachment instance with repository and model information.
        """
        self.repo_id = repo_id
        self.model_name = model_name
        self._meta_value = None
@@ -33,6 +65,12 @@ class Attachment:

    @property
    def _meta(self):
        """
        Load and cache model metadata from the Hugging Face repository.

        :return: Model metadata as a dictionary
        :rtype: dict
        """
        with self._model_lock:
            if self._meta_value is None:
                with open(hf_hub_download(
@@ -47,9 +85,21 @@ class Attachment:

    @property
    def encoder_model(self) -> str:
        """
        Get the encoder model name from metadata.

        :return: Name of the encoder model
        :rtype: str
        """
        return self._meta['encoder_model']

    def _open_model(self):
        """
        Load and cache the ONNX model from Hugging Face.

        :return: Loaded ONNX model
        :rtype: object
        """
        with self._model_lock:
            if self._model is None:
                self._model = open_onnx_model(hf_hub_download(
@@ -61,12 +111,30 @@ class Attachment:

        return self._model

    def _predict_raw(self, embedding: np.ndarray):
    def _predict_raw(self, embedding: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Make raw predictions using the model.

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :return: Tuple of logits and predictions
        :rtype: Tuple[np.ndarray, np.ndarray]
        """
        model = self._open_model()
        logits, prediction = model.run(['logits', 'prediction'], {'input': embedding})
        return logits, prediction

    def _predict_classification(self, embedding: np.ndarray, fmt: Any = 'top'):
        """
        Make classification predictions.

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :param fmt: Format specification for output
        :type fmt: Any
        :return: List of formatted prediction results
        :rtype: list
        """
        labels = np.array(self._meta['problem']['labels'])
        logits, prediction = self._predict_raw(embedding)
        retval = []
@@ -87,6 +155,18 @@ class Attachment:
        return retval

    def _predict_tagging(self, embedding: np.ndarray, threshold: float = 0.3, fmt: Any = 'tags'):
        """
        Make tagging predictions.

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :param threshold: Confidence threshold for tag selection
        :type threshold: float
        :param fmt: Format specification for output
        :type fmt: Any
        :return: List of formatted prediction results
        :rtype: list
        """
        tags = np.array(self._meta['problem']['tags'])
        logits, prediction = self._predict_raw(embedding)
        retval = []
@@ -103,6 +183,16 @@ class Attachment:
        return retval

    def _predict_regression(self, embedding: np.ndarray, fmt: Any = 'full'):
        """
        Make regression predictions.

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :param fmt: Format specification for output
        :type fmt: Any
        :return: List of formatted prediction results
        :rtype: list
        """
        field_names = [name for name, _, _ in self._meta['problem']['fields']]
        logits, prediction = self._predict_raw(embedding)
        retval = []
@@ -118,6 +208,15 @@ class Attachment:
        return retval

    def predict(self, embedding: np.ndarray, **kwargs):
        """
        Make predictions based on the problem type (classification, tagging, or regression).

        :param embedding: Input embedding array
        :type embedding: np.ndarray
        :param kwargs: Additional arguments passed to specific prediction methods
        :return: Prediction results in specified format
        :raises ValueError: If embedding shape is invalid or problem type is unknown
        """
        embedding = embedding.astype(np.float32)
        if len(embedding.shape) == 1:
            single = True
@@ -144,6 +243,21 @@ class Attachment:

@ts_lru_cache()
def open_attachment(repo_id: str, model_name: str, hf_token: Optional[str] = None) -> 'Attachment':
    """
    Create and cache an Attachment instance.

    This function creates a new Attachment instance or returns a cached one
    if it was previously created with the same parameters.

    :param repo_id: The Hugging Face repository ID
    :type repo_id: str
    :param model_name: Name of the model
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]
    :return: An Attachment instance
    :rtype: Attachment
    """
    return Attachment(
        repo_id=repo_id,
        model_name=model_name,