Unverified Commit 4d530a27 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #108 from deepghs/dev/gradio

dev(narugo): add quick gradio demo for classifiers/yolos
parents 75084615 dbdba352
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ ClassifyModel
-----------------------------------------

.. autoclass:: ClassifyModel
    :members: __init__, predict_score, predict, clear
    :members: __init__, predict_score, predict, clear, make_ui, launch_demo



+1 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ YOLOModel
-----------------------------------------

.. autoclass:: YOLOModel
    :members: __init__, predict, clear
    :members: __init__, predict, clear, make_ui, launch_demo



+112 −0
Original line number Diff line number Diff line
@@ -23,12 +23,19 @@ from typing import Tuple, Optional, List, Dict

import numpy as np
from PIL import Image
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem

from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model

try:
    import gradio as gr
except (ImportError, ModuleNotFoundError):
    gr = None

__all__ = [
    'ClassifyModel',
    'classify_predict_score',
@@ -36,6 +43,17 @@ __all__ = [
]


def _check_gradio_env():
    """
    Check if the Gradio library is installed and available.

    :raises EnvironmentError: If Gradio is not installed.
    """
    if gr is None:
        raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
@@ -287,6 +305,100 @@ class ClassifyModel:
        self._models.clear()
        self._labels.clear()

    def make_ui(self, default_model_name: Optional[str] = None):
        """
        Create the user interface components for the classifier model demo.

        This method sets up the Gradio UI components including an image input, model selection dropdown,
        submit button, and output label. It also configures the interaction between these components.

        :param default_model_name: The name of the default model to be selected in the dropdown.
                                   If None, the most recently updated model will be selected.
        :type default_model_name: Optional[str]

        :raises ImportError: If Gradio is not installed or properly configured.

        :Example:
        >>> model = ClassifyModel("username/repo_name")
        >>> model.make_ui(default_model_name="model_v1")
        """

        # demo for classifier model
        _check_gradio_env()
        model_list = self.model_names
        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
            for fileitem in hf_client.get_paths_info(
                    repo_id=self.repo_id,
                    repo_type='model',
                    paths=[f'{model_name}/model.onnx' for model_name in model_list],
                    expand=True,
            ):
                if not selected_time or fileitem.last_commit.date > selected_time:
                    selected_model_name = os.path.dirname(fileitem.path)
                    selected_time = fileitem.last_commit.date
            default_model_name = selected_model_name

        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output = gr.Label(label='Prediction')

            gr_submit.click(
                self.predict_score,
                inputs=[
                    gr_input_image,
                    gr_model,
                ],
                outputs=[gr_output],
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        """
        Launch the Gradio demo for the classifier model.

        This method creates a Gradio Blocks interface, sets up the UI components using make_ui(),
        and launches the demo server.

        :param default_model_name: The name of the default model to be selected in the dropdown.
        :type default_model_name: Optional[str]
        :param server_name: The name of the server to run the demo on. Defaults to None.
        :type server_name: Optional[str]
        :param server_port: The port number to run the demo on. Defaults to None.
        :type server_port: Optional[int]
        :param kwargs: Additional keyword arguments to pass to the Gradio launch method.

        :raises ImportError: If Gradio is not installed or properly configured.

        :Example:
        >>> model = ClassifyModel("username/repo_name")
        >>> model.launch_demo(default_model_name="model_v1", server_name="0.0.0.0", server_port=7860)
        """

        _check_gradio_env()
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
                    gr.HTML(f'<h2 style="text-align: center;">Classifier Demo For {self.repo_id}</h2>')
                    gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). '
                                f'Powered by `dghs-imgutils`\'s quick demo module.')

            with gr.Row():
                self.make_ui(default_model_name=default_model_name)

        demo.launch(
            server_name=server_name,
            server_port=server_port,
            **kwargs,
        )


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel:
+180 −0
Original line number Diff line number Diff line
@@ -20,18 +20,61 @@ from typing import List, Optional, Tuple

import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download

from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model

try:
    import gradio as gr
except (ImportError, ModuleNotFoundError):
    gr = None

__all__ = [
    'YOLOModel',
    'yolo_predict',
]


def _check_gradio_env():
    """
    Check if the Gradio library is installed and available.

    :raises EnvironmentError: If Gradio is not installed.
    """
    if gr is None:
        raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


def _v_fix(v):
    """
    Round and convert a float value to an integer.

    :param v: The float value to be rounded and converted.
    :type v: float
    :return: The rounded integer value.
    :rtype: int
    """
    return int(round(v))


def _bbox_fix(bbox):
    """
    Fix the bounding box coordinates by rounding them to integers.

    :param bbox: The bounding box coordinates.
    :type bbox: tuple
    :return: A tuple of fixed (rounded to integer) bounding box coordinates.
    :rtype: tuple
    """
    return tuple(map(_v_fix, bbox))


def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format.
@@ -403,9 +446,146 @@ class YOLOModel:
    def clear(self):
        """
        Clear cached model and metadata.

        This method removes all cached models and their associated metadata from memory.
        It's useful for freeing up memory or ensuring that the latest versions of models are loaded.
        """
        self._models.clear()

    def make_ui(self, default_model_name: Optional[str] = None,
                default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
        """
        Create a Gradio-based user interface for object detection.

        This method sets up an interactive UI that allows users to upload images,
        select models, and adjust detection parameters. It uses the Gradio library
        to create the interface.

        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25.
        :type default_conf_threshold: float
        :param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
        :type default_iou_threshold: float

        :raises ImportError: If Gradio is not installed in the environment.

        :Example:

        >>> model = YOLOModel("username/repo_name")
        >>> model.make_ui(default_model_name="yolov5s")
        """
        _check_gradio_env()
        model_list = self.model_names
        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
            for fileitem in hf_client.get_paths_info(
                    repo_id=self.repo_id,
                    repo_type='model',
                    paths=[f'{model_name}/model.onnx' for model_name in model_list],
                    expand=True,
            ):
                if not selected_time or fileitem.last_commit.date > selected_time:
                    selected_model_name = os.path.dirname(fileitem.path)
                    selected_time = fileitem.last_commit.date
            default_model_name = selected_model_name

        def _gr_detect(image: ImageTyping, model_name: str,
                       iou_threshold: float = 0.7, score_threshold: float = 0.25) \
                -> gr.AnnotatedImage:
            _, _, labels = self._open_model(model_name=model_name)
            _colors = list(map(str, rnd_colors(len(labels))))
            _color_map = dict(zip(labels, _colors))
            return gr.AnnotatedImage(
                value=(image, [
                    (_bbox_fix(bbox), label)
                    for bbox, label, _ in self.predict(
                        image=image,
                        model_name=model_name,
                        iou_threshold=iou_threshold,
                        conf_threshold=score_threshold,
                    )
                ]),
                color_map=_color_map,
                label='Labeled',
            )

        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
                    gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_image = gr.AnnotatedImage(label="Labeled")

            gr_submit.click(
                _gr_detect,
                inputs=[
                    gr_input_image,
                    gr_model,
                    gr_iou_threshold,
                    gr_score_threshold,
                ],
                outputs=[gr_output_image],
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        """
        Launch a Gradio demo for object detection.

        This method creates and launches a Gradio demo that allows users to interactively
        perform object detection on uploaded images using the YOLO model.

        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25.
        :type default_conf_threshold: float
        :param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
        :type default_iou_threshold: float
        :param server_name: The name of the server to run the demo on. Default is None.
        :type server_name: Optional[str]
        :param server_port: The port to run the demo on. Default is None.
        :type server_port: Optional[int]
        :param kwargs: Additional keyword arguments to pass to gr.Blocks.launch().

        :raises EnvironmentError: If Gradio is not installed in the environment.

        Example:
            >>> model = YOLOModel("username/repo_name")
            >>> model.launch_demo(default_model_name="yolov5s", server_name="0.0.0.0", server_port=7860)
        """
        _check_gradio_env()
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
                    gr.HTML(f'<h2 style="text-align: center;">YOLO Demo For {self.repo_id}</h2>')
                    gr.Markdown(f'This is the quick demo for YOLO model [{self.repo_id}]({repo_url}). '
                                f'Powered by `dghs-imgutils`\'s quick demo module.')

            with gr.Row():
                self.make_ui(
                    default_model_name=default_model_name,
                    default_conf_threshold=default_conf_threshold,
                    default_iou_threshold=default_iou_threshold,
                )

        demo.launch(
            server_name=server_name,
            server_port=server_port,
            **kwargs,
        )


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YOLOModel:

requirements-demo.txt

0 → 100644
+1 −0
Original line number Diff line number Diff line
gradio>=4.44.0
 No newline at end of file