Commit bc825934 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): defaultly disable the dynamic size of yolo models

parent 7df7bf5b
Loading
Loading
Loading
Loading
+35 −12
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ import json
import math
import os
from threading import Lock
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
from PIL import Image
@@ -155,7 +155,8 @@ def _yolo_nms(boxes, scores, iou_threshold: float = 0.7) -> List[int]:
    return keep


def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int = 32):
def _image_preprocess(image: Image.Image, max_infer_size: Union[int, Tuple[int, int]] = 1216,
                      allow_dynamic: bool = False, align: int = 32):
    """
    Preprocess an input image for inference.

@@ -166,6 +167,8 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
    :type image: Image.Image
    :param max_infer_size: Maximum size (width or height) of the processed image. Default is 1216.
    :type max_infer_size: int
    :param allow_dynamic: If True, allows dynamic resizing of the image while maintaining aspect ratio. Default is False.
    :type allow_dynamic: bool
    :param align: Value to align the image dimensions to. Default is 32.
    :type align: int

@@ -183,13 +186,22 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
    >>> print(old_size, new_size)
    (1000, 800) (1216, 992)
    """
    if isinstance(max_infer_size, int):
        max_infer_width, max_infer_height = max_infer_size, max_infer_size
    else:
        max_infer_width, max_infer_height = max_infer_size

    old_width, old_height = image.width, image.height
    new_width, new_height = old_width, old_height
    r = max_infer_size / max(new_width, new_height)
    if allow_dynamic:
        r = min(max_infer_width / new_width, max_infer_height / new_height)
        if r < 1:
            new_width, new_height = new_width * r, new_height * r
        new_width = int(math.ceil(new_width / align) * align)
        new_height = int(math.ceil(new_height / align) * align)
    else:
        new_width, new_height = max_infer_width, max_infer_height

    image = image.resize((new_width, new_height))
    return image, (old_width, old_height), (new_width, new_height)

@@ -539,7 +551,8 @@ class YOLOModel:
                ))
                model_metadata = model.get_modelmeta()
                if 'imgsz' in model_metadata.custom_metadata_map:
                    max_infer_size = max(json.loads(model_metadata.custom_metadata_map['imgsz']))
                    max_infer_size = tuple(json.loads(model_metadata.custom_metadata_map['imgsz']))
                    assert len(max_infer_size) == 2, f'imgsz should have 2 dims, but {max_infer_size!r} found.'
                else:
                    max_infer_size = 640
                names_map = _safe_eval_names_str(model_metadata.custom_metadata_map['names'])
@@ -567,7 +580,8 @@ class YOLOModel:
        return self._model_types[model_name]

    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
                conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                allow_dynamic: bool = False) \
            -> List[Tuple[Tuple[int, int, int, int], str, float]]:
        """
        Perform object detection on an image using the specified YOLO model.
@@ -580,6 +594,9 @@ class YOLOModel:
        :type conf_threshold: float
        :param iou_threshold: IoU threshold for non-maximum suppression. Default is 0.7.
        :type iou_threshold: float
        :param allow_dynamic: If True, allows dynamic resizing of the image while maintaining aspect ratio.
                              Default is False.
        :type allow_dynamic: bool

        :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
        :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]
@@ -594,8 +611,9 @@ class YOLOModel:
        """
        model, max_infer_size, labels = self._open_model(model_name)
        image = load_image(image, mode='RGB')
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size, allow_dynamic=allow_dynamic)
        data = rgb_encode(new_image)[None, ...]
        print(data.shape)
        output, = model.run(['output0'], {'images': data})
        model_type = self._get_model_type(model_name=model_name)
        if model_type == 'yolo':
@@ -669,7 +687,8 @@ class YOLOModel:
            default_model_name = selected_model_name

        def _gr_detect(image: ImageTyping, model_name: str,
                       iou_threshold: float = 0.7, score_threshold: float = 0.25) \
                       iou_threshold: float = 0.7, score_threshold: float = 0.25,
                       allow_dynamic: bool = False) \
                -> gr.AnnotatedImage:
            _, _, labels = self._open_model(model_name=model_name)
            _colors = list(map(str, rnd_colors(len(labels))))
@@ -682,6 +701,7 @@ class YOLOModel:
                        model_name=model_name,
                        iou_threshold=iou_threshold,
                        conf_threshold=score_threshold,
                        allow_dynamic=allow_dynamic,
                    )
                ]),
                color_map=_color_map,
@@ -691,7 +711,9 @@ class YOLOModel:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size')
                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
                    gr_score_threshold = gr.Slider(0.0, 1.0, default_conf_threshold, label='Score Threshold')
@@ -708,6 +730,7 @@ class YOLOModel:
                    gr_model,
                    gr_iou_threshold,
                    gr_score_threshold,
                    gr_allow_dynamic,
                ],
                outputs=[gr_output_image],
            )