Commit 8a4098af authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add gradio demo for yolo

parent 75084615
Loading
Loading
Loading
Loading
+87 −0
Original line number Diff line number Diff line
@@ -20,18 +20,39 @@ from typing import List, Optional, Tuple

import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
from natsort import natsorted

from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model

try:
    import gradio as gr
except (ImportError, ModuleNotFoundError):
    gr = None

__all__ = [
    'YOLOModel',
    'yolo_predict',
]


def _check_gradio_env():
    if gr is None:
        raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


def _v_fix(v):
    return int(round(v))


def _bbox_fix(bbox):
    return tuple(map(_v_fix, bbox))


def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format.
@@ -406,6 +427,72 @@ class YOLOModel:
        """
        self._models.clear()

    def make_ui(self, default_model_name: Optional[str] = None,
                default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
        _check_gradio_env()
        model_list = self.model_names
        default_model_name = default_model_name or natsorted(self.model_names)[-1]

        def _gr_detect(image: ImageTyping, model_name: str,
                       iou_threshold: float = 0.7, score_threshold: float = 0.25) \
                -> gr.AnnotatedImage:
            _, _, labels = self._open_model(model_name=model_name)
            _colors = list(map(str, rnd_colors(len(labels))))
            _color_map = dict(zip(labels, _colors))
            return gr.AnnotatedImage(
                value=(image, [
                    (_bbox_fix(bbox), label)
                    for bbox, label, _ in self.predict(
                        image=image,
                        model_name=model_name,
                        iou_threshold=iou_threshold,
                        conf_threshold=score_threshold,
                    )
                ]),
                color_map=_color_map,
                label='Labeled',
            )

        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
                    gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_image = gr.AnnotatedImage(label="Labeled")

            gr_submit.click(
                _gr_detect,
                inputs=[
                    gr_input_image,
                    gr_model,
                    gr_iou_threshold,
                    gr_score_threshold,
                ],
                outputs=[gr_output_image],
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
                    server_port: int = 7860, **kwargs):
        _check_gradio_env()
        with gr.Blocks() as demo:
            self.make_ui(
                default_model_name=default_model_name,
                default_conf_threshold=default_conf_threshold,
                default_iou_threshold=default_iou_threshold,
            )

        demo.launch(
            server_port=server_port,
            **kwargs,
        )


@lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YOLOModel:

requirements-demo.txt

0 → 100644
+1 −0
Original line number Diff line number Diff line
gradio>=4.44.0
 No newline at end of file