Commit bed0f541 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add yolo generic

parent 195cc8f4
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -4,3 +4,4 @@ Overview:
"""
from .classify import *
from .enhance import *
from .yolo import *
+147 −5
Original line number Diff line number Diff line
"""
This module provides functionality for YOLO object detection using ONNX models from Hugging Face.

It includes utilities for preprocessing images, performing object detection, and post-processing
the results. The main components are:

1. YOLOModel class: Manages YOLO models from a Hugging Face repository.
2. Helper functions for coordinate conversion, non-maximum suppression, and image processing.
3. A high-level function 'yolo_predict' for easy object detection on images.

The module supports various image input types and allows customization of confidence and IoU thresholds.
"""

import ast
import json
import math
@@ -10,10 +23,14 @@ from PIL import Image
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download

from imgutils.data import load_image, rgb_encode
from ..data import ImageTyping
from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model

__all__ = [
    'YOLOModel',
    'yolo_predict',
]


def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
    """
@@ -250,17 +267,54 @@ def _safe_eval_names_str(names_str):


class YOLOModel:
    """
    A class to manage YOLO models from a Hugging Face repository.

    This class handles the loading, caching, and inference of YOLO models.

    :param repo_id: The Hugging Face repository ID containing the YOLO models.
    :type repo_id: str
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :Example:

    >>> model = YOLOModel("username/repo_name")
    >>> image = Image.open("path/to/image.jpg")
    >>> detections = model.predict(image, "model_name")
    """

    def __init__(self, repo_id: str, hf_token: Optional[str] = None):
        """
        Initialize the YOLOModel.

        :param repo_id: The Hugging Face repository ID containing the YOLO models.
        :type repo_id: str
        :param hf_token: Optional Hugging Face authentication token.
        :type hf_token: Optional[str]
        """
        self.repo_id = repo_id
        self._model_names = None
        self._models = {}
        self._hf_token = hf_token

    def _get_hf_token(self):
        """
        Get the Hugging Face token, either from the instance or environment variable.

        :return: Hugging Face token.
        :rtype: Optional[str]
        """
        return self._hf_token or os.environ.get('HF_TOKEN')

    @property
    def model_names(self) -> List[str]:
        """
        Get the list of available model names in the repository.

        :return: List of model names.
        :rtype: List[str]
        """
        if self._model_names is None:
            hf_fs = HfFileSystem(token=self._get_hf_token())
            self._model_names = [
@@ -275,11 +329,26 @@ class YOLOModel:
        return self._model_names

    def _check_model_name(self, model_name: str):
        """
        Check if the given model name is valid for this repository.

        :param model_name: Name of the model to check.
        :type model_name: str
        :raises ValueError: If the model name is not found in the repository.
        """
        if model_name not in self.model_names:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

    def _open_model(self, model_name: str):
        """
        Open and cache a YOLO model.

        :param model_name: Name of the model to open.
        :type model_name: str
        :return: Tuple containing the ONNX model, maximum inference size, and labels.
        :rtype: tuple
        """
        if model_name not in self._models:
            self._check_model_name(model_name)
            model = open_onnx_model(hf_hub_download(
@@ -293,9 +362,7 @@ class YOLOModel:
            else:
                max_infer_size = 640
            names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names'])
            labels = ['<unknown>'] * (max(names_map.keys()) + 1)
            for id_, name in names_map.items():
                labels[id_] = name
            labels = [names_map[i] for i in range(len(names_map))]
            self._models[model_name] = (model, max_infer_size, labels)

        return self._models[model_name]
@@ -303,6 +370,29 @@ class YOLOModel:
    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
            -> List[Tuple[Tuple[int, int, int, int], str, float]]:
        """
        Perform object detection on an image using the specified YOLO model.

        :param image: Input image for object detection.
        :type image: ImageTyping
        :param model_name: Name of the YOLO model to use.
        :type model_name: str
        :param conf_threshold: Confidence threshold for filtering detections. Default is 0.25.
        :type conf_threshold: float
        :param iou_threshold: IoU threshold for non-maximum suppression. Default is 0.7.
        :type iou_threshold: float

        :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
        :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]

        :Example:

        >>> model = YOLOModel("username/repo_name")
        >>> image = Image.open("path/to/image.jpg")
        >>> detections = model.predict(image, "model_name")
        >>> print(detections[0])  # First detection
        ((100, 200, 300, 400), 'person', 0.95)
        """
        model, max_infer_size, labels = self._open_model(model_name)
        image = load_image(image, mode='RGB')
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
@@ -311,17 +401,69 @@ class YOLOModel:
        return _data_postprocess(output[0], conf_threshold, iou_threshold, old_size, new_size, labels)

    def clear(self):
        """
        Clear cached model and metadata.
        """
        self._models.clear()


@lru_cache()
def _open_models_for_repo_id(repo_id: str) -> YOLOModel:
    """
    Load and cache a YOLO model from a Hugging Face repository.

    This function uses the `lru_cache` decorator to cache the loaded models,
    improving performance for repeated calls with the same repository ID.

    :param repo_id: The Hugging Face repository ID for the YOLO model.
    :type repo_id: str

    :return: The loaded YOLO model.
    :rtype: YOLOModel

    :raises Exception: If there's an error loading the model from the repository.

    Usage:
        >>> model = _open_models_for_repo_id("yolov5/yolov5s")
        >>> # Subsequent calls with the same repo_id will return the cached model
        >>> same_model = _open_models_for_repo_id("yolov5/yolov5s")
    """
    return YOLOModel(repo_id)


def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
                 conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Perform object detection on an image using a YOLO model from a Hugging Face repository.

    This function is a high-level wrapper around the YOLOModel class, providing a simple
    interface for object detection without needing to explicitly manage model instances.

    :param image: Input image for object detection.
    :type image: ImageTyping
    :param repo_id: The Hugging Face repository ID containing the YOLO models.
    :type repo_id: str
    :param model_name: Name of the YOLO model to use.
    :type model_name: str
    :param conf_threshold: Confidence threshold for filtering detections. Default is 0.25.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold for non-maximum suppression. Default is 0.7.
    :type iou_threshold: float
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
    :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]

    :Example:

    >>> from PIL import Image
    >>> image = Image.open("path/to/image.jpg")
    >>> detections = yolo_predict(image, "username/repo_name", "model_name")
    >>> print(detections[0])  # First detection
    ((100, 200, 300, 400), 'person', 0.95)
    """
    return _open_models_for_repo_id(repo_id).predict(
        image=image,
        model_name=model_name,