Commit e27e8e46 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add this option for yolo seg

parent de66e5d9
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -770,6 +770,8 @@ class YOLOModel:
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool

        :raises ImportError: If Gradio is not installed in the environment.
        :raises EnvironmentError: If in OFFLINE mode and no default_model_name is provided.
@@ -920,6 +922,8 @@ class YOLOModel:
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool
        :param server_name: The name of the server to run the demo on. Default is None.
        :type server_name: Optional[str]
        :param server_port: The port to run the demo on. Default is None.
+117 −13
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ post-processing segmentation results.
import json
import os
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from threading import Lock
@@ -244,6 +245,7 @@ class YOLOSegmentationModel:
        self._model_names = None
        self._models = {}
        self._model_types = {}
        self._model_thresholds = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_meta_lock = Lock()
@@ -371,8 +373,39 @@ class YOLOSegmentationModel:

        return self._model_types[model_name]

    def _get_default_thresholds(self, model_name: str) -> Optional[float]:
        """
        Get the default confidence threshold for the specified model.

        This method attempts to download and read the threshold.json file from the repository
        to determine the recommended confidence threshold for the model. If the file is not found,
        it returns None.

        :param model_name: Name of the model to get the default threshold for.
        :type model_name: str
        :return: The default confidence threshold, or None if not specified.
        :rtype: Optional[float]
        """
        with self._model_meta_lock:
            if model_name not in self._model_thresholds:
                try:
                    with open(hf_hub_download(
                            repo_id=self.repo_id,
                            repo_type='model',
                            filename=f'{model_name}/threshold.json',
                            revision='main',
                            token=self._get_hf_token()
                    )) as f:
                        default_threshold = json.load(f)['threshold']
                except (EntryNotFoundError,):
                    default_threshold = None

                self._model_thresholds[model_name] = default_threshold

        return self._model_thresholds[model_name]

    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                conf_threshold: Optional[float] = None, iou_threshold: float = 0.7,
                allow_dynamic: bool = False) \
            -> List[Tuple[Tuple[int, int, int, int], str, float, np.ndarray]]:
        """
@@ -382,8 +415,10 @@ class YOLOSegmentationModel:
        :type image: ImageTyping
        :param model_name: Name of the model to use for prediction.
        :type model_name: str
        :param conf_threshold: Confidence threshold for filtering detections (0.0-1.0).
        :type conf_threshold: float
        :param conf_threshold: Confidence threshold for filtering detections.
                           Use recommended threshold when not assigned,
                           use 0.25 when no recommended threshold or in OFFLINE mode.
        :type conf_threshold: Optional[float]
        :param iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0).
        :type iou_threshold: float
        :param allow_dynamic: Whether to allow dynamic resizing of the input image.
@@ -405,6 +440,12 @@ class YOLOSegmentationModel:
        >>> for bbox, label, confidence, mask in results:
        ...     print(f"Found {label} with confidence {confidence:.2f}")
        """
        if conf_threshold is None:
            # try to use default recommended threshold provided by model trainer
            conf_threshold = self._get_default_thresholds(model_name)
            if conf_threshold is None:
                conf_threshold = 0.25  # default threshold from YOLO official implement

        model, max_infer_size, labels, exec_lock = self._open_model(model_name)
        image = load_image(image, mode='RGB')
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic)
@@ -435,9 +476,11 @@ class YOLOSegmentationModel:
        self._model_names = None
        self._models.clear()
        self._model_types.clear()
        self._model_thresholds.clear()

    def make_ui(self, default_model_name: Optional[str] = None,
                default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
                default_conf_threshold: Optional[float] = None, default_iou_threshold: float = 0.7,
                apply_default_threshold: bool = True):
        """
        Create a Gradio-based user interface for object detection.

@@ -448,10 +491,14 @@ class YOLOSegmentationModel:
        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25.
        :type default_conf_threshold: float
        :param default_conf_threshold: Default confidence threshold for the demo.
                                       Use recommended threshold when not assigned,
                                       use 0.25 when no recommended threshold or in OFFLINE mode.
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool

        :raises ImportError: If Gradio is not installed in the environment.
        :raises EnvironmentError: If in OFFLINE mode and no default_model_name is provided.
@@ -466,6 +513,9 @@ class YOLOSegmentationModel:
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')
        if model_list is _OFFLINE and default_conf_threshold is None:
            warnings.warn('You are in OFFLINE model, auto confidence threshold disabled, '
                          'will use 0.25 instead for non-local models.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
@@ -504,6 +554,19 @@ class YOLOSegmentationModel:
            )

        with gr.Row():
            def _get_default_thresholds(model_name):
                v = self._get_default_thresholds(model_name)
                if v is None:
                    if model_list is _OFFLINE:
                        gr.Warning(f'No local default threshold found for model {model_name!r}.\n'
                                   f'Offline mode has been enabled.\n'
                                   f'Will use default threshold 0.25 instead.')
                    else:
                        gr.Warning(f'No default threshold found for model {model_name!r}.\n'
                                   f'Will use default threshold 0.25 instead.')
                    v = 0.25
                return v

            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
@@ -513,15 +576,48 @@ class YOLOSegmentationModel:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)
                    gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size')
                    if default_conf_threshold is None:
                        gr_auto_apply_threshold = gr.Checkbox(
                            value=apply_default_threshold,
                            label='Auto Use Default Score Threshold When Model Changed'
                        )
                    else:
                        gr_auto_apply_threshold = None

                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
                    if default_conf_threshold is not None:
                        gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')
                    else:
                        gr_score_threshold = gr.Slider(0.0, 1.0, _get_default_thresholds(default_model_name),
                                                       label='Score Threshold')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_image = gr.AnnotatedImage(label="Labeled")

            if gr_auto_apply_threshold is not None:
                def _gr_model_change(model_name, need_apply):
                    if need_apply:
                        return gr.update(
                            value=_get_default_thresholds(model_name),
                        )
                    else:
                        return gr.update()

                gr_model.change(
                    _gr_model_change,
                    inputs=[gr_model, gr_auto_apply_threshold],
                    outputs=[gr_score_threshold],
                )

                gr_auto_apply_threshold.change(
                    _gr_model_change,
                    inputs=[gr_model, gr_auto_apply_threshold],
                    outputs=[gr_score_threshold],
                )

            gr_submit.click(
                _gr_detect,
                inputs=[
@@ -535,7 +631,8 @@ class YOLOSegmentationModel:
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
                    default_conf_threshold: Optional[float] = None, default_iou_threshold: float = 0.7,
                    apply_default_threshold: bool = True,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        """
        Launch a Gradio demo for object detection.
@@ -546,10 +643,14 @@ class YOLOSegmentationModel:
        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25.
        :type default_conf_threshold: float
        :param default_conf_threshold: Default confidence threshold for the demo.
                                       Use recommended threshold when not assigned,
                                       use 0.25 when no recommended threshold or in OFFLINE mode.
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool
        :param server_name: The name of the server to run the demo on. Default is None.
        :type server_name: Optional[str]
        :param server_port: The port to run the demo on. Default is None.
@@ -578,6 +679,7 @@ class YOLOSegmentationModel:
                    default_model_name=default_model_name,
                    default_conf_threshold=default_conf_threshold,
                    default_iou_threshold=default_iou_threshold,
                    apply_default_threshold=apply_default_threshold,
                )

        demo.launch(
@@ -607,7 +709,7 @@ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YO


def yolo_seg_predict(image: ImageTyping, repo_id: str, model_name: str,
                     conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                     conf_threshold: Optional[float] = None, iou_threshold: float = 0.7,
                     hf_token: Optional[str] = None, **kwargs) \
        -> List[Tuple[Tuple[int, int, int, int], str, float, np.ndarray]]:
    """
@@ -622,8 +724,10 @@ def yolo_seg_predict(image: ImageTyping, repo_id: str, model_name: str,
    :type repo_id: str
    :param model_name: Name of the specific model to use.
    :type model_name: str
    :param conf_threshold: Confidence threshold for filtering detections (0.0-1.0).
    :type conf_threshold: float
    :param conf_threshold: Confidence threshold for filtering detections.
                           Use recommended threshold when not assigned,
                           use 0.25 when no recommended threshold or in OFFLINE mode.
    :type conf_threshold: Optional[float]
    :param iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0).
    :type iou_threshold: float
    :param hf_token: Hugging Face API token for accessing private repositories.