Unverified Commit 3bfa48d4 authored by Naomi Rue Golding's avatar Naomi Rue Golding Committed by GitHub
Browse files

Merge pull request #165 from deepghs/dev/ythreshold

dev(narugo): add extra default threshold system for YOLOModel and YOLOSegmentationModel
parents f3d7c6f2 e27e8e46
Loading
Loading
Loading
Loading
+118 −14
Original line number Diff line number Diff line
@@ -13,14 +13,15 @@ The module supports various image input types and allows customization of confid

import ast
import json
import math
import os
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from threading import Lock
from typing import List, Optional, Tuple, Union

import math
import numpy as np
import requests
from PIL import Image
@@ -515,6 +516,7 @@ class YOLOModel:
        self._model_names = None
        self._models = {}
        self._model_types = {}
        self._model_thresholds = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_meta_lock = Lock()
@@ -634,8 +636,39 @@ class YOLOModel:

        return self._model_types[model_name]

    def _get_default_thresholds(self, model_name: str) -> Optional[float]:
        """
        Get the default confidence threshold for the specified model.

        This method downloads ``<model_name>/threshold.json`` from the ``main`` revision of the
        repository and reads its ``threshold`` field. The outcome is stored in
        ``self._model_thresholds``, so later calls for the same model name are answered from
        that cache. A missing file is cached as ``None`` as well.

        :param model_name: Name of the model to get the default threshold for.
        :type model_name: str
        :return: The default confidence threshold, or None if no threshold file was found.
        :rtype: Optional[float]
        """
        with self._model_meta_lock:
            if model_name not in self._model_thresholds:
                try:
                    threshold_file = hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename=f'{model_name}/threshold.json',
                        revision='main',
                        token=self._get_hf_token(),
                    )
                    # JSON text is UTF-8; decode it explicitly instead of relying on the
                    # platform's locale encoding, which differs between systems.
                    with open(threshold_file, 'r', encoding='utf-8') as f:
                        default_threshold = json.load(f)['threshold']
                except (EntryNotFoundError,):
                    # No threshold file for this model; remember the miss.
                    default_threshold = None

                self._model_thresholds[model_name] = default_threshold

        return self._model_thresholds[model_name]

    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                conf_threshold: Optional[float] = None, iou_threshold: float = 0.7,
                allow_dynamic: bool = False) \
            -> List[Tuple[Tuple[int, int, int, int], str, float]]:
        """
@@ -645,8 +678,10 @@ class YOLOModel:
        :type image: ImageTyping
        :param model_name: Name of the YOLO model to use.
        :type model_name: str
        :param conf_threshold: Confidence threshold for filtering detections. Default is 0.25.
        :type conf_threshold: float
        :param conf_threshold: Confidence threshold for filtering detections.
                               Use recommended threshold when not assigned,
                               use 0.25 when no recommended threshold or in OFFLINE mode.
        :type conf_threshold: Optional[float]
        :param iou_threshold: IoU threshold for non-maximum suppression. Default is 0.7.
        :type iou_threshold: float
        :param allow_dynamic: If True, allows dynamic resizing of the image while maintaining aspect ratio.
@@ -664,6 +699,12 @@ class YOLOModel:
        >>> print(detections[0])  # First detection
        ((100, 200, 300, 400), 'person', 0.95)
        """
        if conf_threshold is None:
            # try to use default recommended threshold provided by model trainer
            conf_threshold = self._get_default_thresholds(model_name)
            if conf_threshold is None:
                conf_threshold = 0.25  # default threshold from YOLO official implement

        model, max_infer_size, labels, exec_lock = self._open_model(model_name)
        image = load_image(image, mode='RGB')
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic)
@@ -708,9 +749,11 @@ class YOLOModel:
        self._model_names = None
        self._models.clear()
        self._model_types.clear()
        self._model_thresholds.clear()

    def make_ui(self, default_model_name: Optional[str] = None,
                default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
                default_conf_threshold: Optional[float] = None, default_iou_threshold: float = 0.7,
                apply_default_threshold: bool = True):
        """
        Create a Gradio-based user interface for object detection.

@@ -721,10 +764,14 @@ class YOLOModel:
        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25.
        :type default_conf_threshold: float
        :param default_conf_threshold: Default confidence threshold for the demo.
                                       Use recommended threshold when not assigned,
                                       use 0.25 when no recommended threshold or in OFFLINE mode.
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool

        :raises ImportError: If Gradio is not installed in the environment.
        :raises EnvironmentError: If in OFFLINE mode and no default_model_name is provided.
@@ -739,6 +786,9 @@ class YOLOModel:
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')
        if model_list is _OFFLINE and default_conf_threshold is None:
            warnings.warn('You are in OFFLINE model, auto confidence threshold disabled, '
                          'will use 0.25 instead for non-local models.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
@@ -777,6 +827,19 @@ class YOLOModel:
            )

        with gr.Row():
            def _get_default_thresholds(model_name):
                """
                Resolve the score threshold to show in the UI for ``model_name``.

                Returns the model's recommended threshold when one is available. Otherwise
                a Gradio warning is raised (its wording depends on whether the model list
                is in OFFLINE state) and 0.25 is returned.
                """
                recommended = self._get_default_thresholds(model_name)
                if recommended is not None:
                    return recommended

                if model_list is _OFFLINE:
                    message = (f'No local default threshold found for model {model_name!r}.\n'
                               f'Offline mode has been enabled.\n'
                               f'Will use default threshold 0.25 instead.')
                else:
                    message = (f'No default threshold found for model {model_name!r}.\n'
                               f'Will use default threshold 0.25 instead.')
                gr.Warning(message)
                return 0.25

            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
@@ -786,15 +849,48 @@ class YOLOModel:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)
                    gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size')
                    if default_conf_threshold is None:
                        gr_auto_apply_threshold = gr.Checkbox(
                            value=apply_default_threshold,
                            label='Auto Use Default Score Threshold When Model Changed'
                        )
                    else:
                        gr_auto_apply_threshold = None

                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
                    if default_conf_threshold is not None:
                        gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')
                    else:
                        gr_score_threshold = gr.Slider(0.0, 1.0, _get_default_thresholds(default_model_name),
                                                       label='Score Threshold')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_image = gr.AnnotatedImage(label="Labeled")

            if gr_auto_apply_threshold is not None:
                def _gr_model_change(model_name, need_apply):
                    """
                    Produce the score-threshold slider update for a model / checkbox change.

                    When ``need_apply`` is falsy an empty update is returned; otherwise the
                    update carries the threshold resolved for ``model_name``.
                    """
                    if not need_apply:
                        return gr.update()

                    resolved = _get_default_thresholds(model_name)
                    return gr.update(value=resolved)

                gr_model.change(
                    _gr_model_change,
                    inputs=[gr_model, gr_auto_apply_threshold],
                    outputs=[gr_score_threshold],
                )

                gr_auto_apply_threshold.change(
                    _gr_model_change,
                    inputs=[gr_model, gr_auto_apply_threshold],
                    outputs=[gr_score_threshold],
                )

            gr_submit.click(
                _gr_detect,
                inputs=[
@@ -808,7 +904,8 @@ class YOLOModel:
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
                    default_conf_threshold: Optional[float] = None, default_iou_threshold: float = 0.7,
                    apply_default_threshold: bool = True,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        """
        Launch a Gradio demo for object detection.
@@ -819,10 +916,14 @@ class YOLOModel:
        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25.
        :type default_conf_threshold: float
        :param default_conf_threshold: Default confidence threshold for the demo.
                                       Use recommended threshold when not assigned,
                                       use 0.25 when no recommended threshold or in OFFLINE mode.
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool
        :param server_name: The name of the server to run the demo on. Default is None.
        :type server_name: Optional[str]
        :param server_port: The port to run the demo on. Default is None.
@@ -850,6 +951,7 @@ class YOLOModel:
                    default_model_name=default_model_name,
                    default_conf_threshold=default_conf_threshold,
                    default_iou_threshold=default_iou_threshold,
                    apply_default_threshold=apply_default_threshold,
                )

        demo.launch(
@@ -886,7 +988,7 @@ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YO


def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
                 conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                 conf_threshold: Optional[float] = None, iou_threshold: float = 0.7,
                 hf_token: Optional[str] = None, **kwargs) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
@@ -901,8 +1003,10 @@ def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
    :type repo_id: str
    :param model_name: Name of the YOLO model to use.
    :type model_name: str
    :param conf_threshold: Confidence threshold for filtering detections. Default is 0.25.
    :type conf_threshold: float
    :param conf_threshold: Confidence threshold for filtering detections.
                           Use recommended threshold when not assigned,
                           use 0.25 when no recommended threshold or in OFFLINE mode.
    :type conf_threshold: Optional[float]
    :param iou_threshold: IoU threshold for non-maximum suppression. Default is 0.7.
    :type iou_threshold: float
    :param hf_token: Optional Hugging Face authentication token.
+117 −13
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@ post-processing segmentation results.
import json
import os
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from threading import Lock
@@ -244,6 +245,7 @@ class YOLOSegmentationModel:
        self._model_names = None
        self._models = {}
        self._model_types = {}
        self._model_thresholds = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_meta_lock = Lock()
@@ -371,8 +373,39 @@ class YOLOSegmentationModel:

        return self._model_types[model_name]

    def _get_default_thresholds(self, model_name: str) -> Optional[float]:
        """
        Get the default confidence threshold for the specified model.

        This method downloads ``<model_name>/threshold.json`` from the ``main`` revision of the
        repository and reads its ``threshold`` field. The outcome is stored in
        ``self._model_thresholds``, so later calls for the same model name are answered from
        that cache. A missing file is cached as ``None`` as well.

        :param model_name: Name of the model to get the default threshold for.
        :type model_name: str
        :return: The default confidence threshold, or None if no threshold file was found.
        :rtype: Optional[float]
        """
        with self._model_meta_lock:
            if model_name not in self._model_thresholds:
                try:
                    threshold_file = hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename=f'{model_name}/threshold.json',
                        revision='main',
                        token=self._get_hf_token(),
                    )
                    # JSON text is UTF-8; decode it explicitly instead of relying on the
                    # platform's locale encoding, which differs between systems.
                    with open(threshold_file, 'r', encoding='utf-8') as f:
                        default_threshold = json.load(f)['threshold']
                except (EntryNotFoundError,):
                    # No threshold file for this model; remember the miss.
                    default_threshold = None

                self._model_thresholds[model_name] = default_threshold

        return self._model_thresholds[model_name]

    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                conf_threshold: Optional[float] = None, iou_threshold: float = 0.7,
                allow_dynamic: bool = False) \
            -> List[Tuple[Tuple[int, int, int, int], str, float, np.ndarray]]:
        """
@@ -382,8 +415,10 @@ class YOLOSegmentationModel:
        :type image: ImageTyping
        :param model_name: Name of the model to use for prediction.
        :type model_name: str
        :param conf_threshold: Confidence threshold for filtering detections (0.0-1.0).
        :type conf_threshold: float
        :param conf_threshold: Confidence threshold for filtering detections.
                           Use recommended threshold when not assigned,
                           use 0.25 when no recommended threshold or in OFFLINE mode.
        :type conf_threshold: Optional[float]
        :param iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0).
        :type iou_threshold: float
        :param allow_dynamic: Whether to allow dynamic resizing of the input image.
@@ -405,6 +440,12 @@ class YOLOSegmentationModel:
        >>> for bbox, label, confidence, mask in results:
        ...     print(f"Found {label} with confidence {confidence:.2f}")
        """
        if conf_threshold is None:
            # try to use default recommended threshold provided by model trainer
            conf_threshold = self._get_default_thresholds(model_name)
            if conf_threshold is None:
                conf_threshold = 0.25  # default threshold from YOLO official implement

        model, max_infer_size, labels, exec_lock = self._open_model(model_name)
        image = load_image(image, mode='RGB')
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic)
@@ -435,9 +476,11 @@ class YOLOSegmentationModel:
        self._model_names = None
        self._models.clear()
        self._model_types.clear()
        self._model_thresholds.clear()

    def make_ui(self, default_model_name: Optional[str] = None,
                default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
                default_conf_threshold: Optional[float] = None, default_iou_threshold: float = 0.7,
                apply_default_threshold: bool = True):
        """
        Create a Gradio-based user interface for object detection.

@@ -448,10 +491,14 @@ class YOLOSegmentationModel:
        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the UI. Default is 0.25.
        :type default_conf_threshold: float
        :param default_conf_threshold: Default confidence threshold for the demo.
                                       Use recommended threshold when not assigned,
                                       use 0.25 when no recommended threshold or in OFFLINE mode.
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the UI. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool

        :raises ImportError: If Gradio is not installed in the environment.
        :raises EnvironmentError: If in OFFLINE mode and no default_model_name is provided.
@@ -466,6 +513,9 @@ class YOLOSegmentationModel:
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')
        if model_list is _OFFLINE and default_conf_threshold is None:
            warnings.warn('You are in OFFLINE model, auto confidence threshold disabled, '
                          'will use 0.25 instead for non-local models.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
@@ -504,6 +554,19 @@ class YOLOSegmentationModel:
            )

        with gr.Row():
            def _get_default_thresholds(model_name):
                """
                Resolve the score threshold to show in the UI for ``model_name``.

                Returns the model's recommended threshold when one is available. Otherwise
                a Gradio warning is raised (its wording depends on whether the model list
                is in OFFLINE state) and 0.25 is returned.
                """
                recommended = self._get_default_thresholds(model_name)
                if recommended is not None:
                    return recommended

                if model_list is _OFFLINE:
                    message = (f'No local default threshold found for model {model_name!r}.\n'
                               f'Offline mode has been enabled.\n'
                               f'Will use default threshold 0.25 instead.')
                else:
                    message = (f'No default threshold found for model {model_name!r}.\n'
                               f'Will use default threshold 0.25 instead.')
                gr.Warning(message)
                return 0.25

            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
@@ -513,15 +576,48 @@ class YOLOSegmentationModel:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)
                    gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size')
                    if default_conf_threshold is None:
                        gr_auto_apply_threshold = gr.Checkbox(
                            value=apply_default_threshold,
                            label='Auto Use Default Score Threshold When Model Changed'
                        )
                    else:
                        gr_auto_apply_threshold = None

                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
                    if default_conf_threshold is not None:
                        gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')
                    else:
                        gr_score_threshold = gr.Slider(0.0, 1.0, _get_default_thresholds(default_model_name),
                                                       label='Score Threshold')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_image = gr.AnnotatedImage(label="Labeled")

            if gr_auto_apply_threshold is not None:
                def _gr_model_change(model_name, need_apply):
                    """
                    Produce the score-threshold slider update for a model / checkbox change.

                    When ``need_apply`` is falsy an empty update is returned; otherwise the
                    update carries the threshold resolved for ``model_name``.
                    """
                    if not need_apply:
                        return gr.update()

                    resolved = _get_default_thresholds(model_name)
                    return gr.update(value=resolved)

                gr_model.change(
                    _gr_model_change,
                    inputs=[gr_model, gr_auto_apply_threshold],
                    outputs=[gr_score_threshold],
                )

                gr_auto_apply_threshold.change(
                    _gr_model_change,
                    inputs=[gr_model, gr_auto_apply_threshold],
                    outputs=[gr_score_threshold],
                )

            gr_submit.click(
                _gr_detect,
                inputs=[
@@ -535,7 +631,8 @@ class YOLOSegmentationModel:
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7,
                    default_conf_threshold: Optional[float] = None, default_iou_threshold: float = 0.7,
                    apply_default_threshold: bool = True,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        """
        Launch a Gradio demo for object detection.
@@ -546,10 +643,14 @@ class YOLOSegmentationModel:
        :param default_model_name: The name of the default model to use.
                                   If None, the most recently updated model is selected.
        :type default_model_name: Optional[str]
        :param default_conf_threshold: Default confidence threshold for the demo. Default is 0.25.
        :type default_conf_threshold: float
        :param default_conf_threshold: Default confidence threshold for the demo.
                                       Use recommended threshold when not assigned,
                                       use 0.25 when no recommended threshold or in OFFLINE mode.
        :type default_conf_threshold: Optional[float]
        :param default_iou_threshold: Default IoU threshold for the demo. Default is 0.7.
        :type default_iou_threshold: float
        :param apply_default_threshold: Enable score threshold auto-switch or not. Default is True.
        :type apply_default_threshold: bool
        :param server_name: The name of the server to run the demo on. Default is None.
        :type server_name: Optional[str]
        :param server_port: The port to run the demo on. Default is None.
@@ -578,6 +679,7 @@ class YOLOSegmentationModel:
                    default_model_name=default_model_name,
                    default_conf_threshold=default_conf_threshold,
                    default_iou_threshold=default_iou_threshold,
                    apply_default_threshold=apply_default_threshold,
                )

        demo.launch(
@@ -607,7 +709,7 @@ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YO


def yolo_seg_predict(image: ImageTyping, repo_id: str, model_name: str,
                     conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                     conf_threshold: Optional[float] = None, iou_threshold: float = 0.7,
                     hf_token: Optional[str] = None, **kwargs) \
        -> List[Tuple[Tuple[int, int, int, int], str, float, np.ndarray]]:
    """
@@ -622,8 +724,10 @@ def yolo_seg_predict(image: ImageTyping, repo_id: str, model_name: str,
    :type repo_id: str
    :param model_name: Name of the specific model to use.
    :type model_name: str
    :param conf_threshold: Confidence threshold for filtering detections (0.0-1.0).
    :type conf_threshold: float
    :param conf_threshold: Confidence threshold for filtering detections.
                           Use recommended threshold when not assigned,
                           use 0.25 when no recommended threshold or in OFFLINE mode.
    :type conf_threshold: Optional[float]
    :param iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0).
    :type iou_threshold: float
    :param hf_token: Hugging Face API token for accessing private repositories.