Commit 48de7f54 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add support for rtdetr detection model

parent 33771004
Loading
Loading
Loading
Loading
+18 −10
Original line number Diff line number Diff line
@@ -173,7 +173,7 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
        - The preprocessed image
        - Original image dimensions (width, height)
        - New image dimensions (width, height)
    :rtype: tuple(Image.Image, tuple(int, int), tuple(int, int))
    :rtype: tuple(Image.Image, Tuple[int, int], Tuple[int, int])

    :Example:

@@ -194,7 +194,7 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
    return image, (old_width, old_height), (new_width, new_height)


def _xy_postprocess(x, y, old_size, new_size):
def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):
    """
    Convert coordinates from the preprocessed image size back to the original image size.

@@ -203,12 +203,12 @@ def _xy_postprocess(x, y, old_size, new_size):
    :param y: Y-coordinate in the preprocessed image.
    :type y: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: tuple(int, int)
    :type old_size: Tuple[int, int]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: tuple(int, int)
    :type new_size: Tuple[int, int]

    :return: Adjusted (x, y) coordinates for the original image size.
    :rtype: tuple(int, int)
    :rtype: Tuple[int, int]

    :Example:

@@ -223,7 +223,8 @@ def _xy_postprocess(x, y, old_size, new_size):
    return x, y


def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size, labels: List[str]):
def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
                      old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]):
    """
    Post-process the raw output from the object detection model.

@@ -237,9 +238,9 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: tuple(int, int)
    :type old_size: Tuple[int, int]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: tuple(int, int)
    :type new_size: Tuple[int, int]
    :param labels: List of class labels.
    :type labels: List[str]

@@ -249,7 +250,7 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,
    :Example:

    >>> output = np.array([[10, 10, 20, 20, 0.9, 0.1]])
    >>> _data_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
    >>> _yolo_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
    [((7, 7, 15, 15), 'cat', 0.9)]
    """
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
@@ -453,7 +454,14 @@ class YOLOModel:
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
        data = rgb_encode(new_image)[None, ...]
        output, = model.run(['output0'], {'images': data})
        return _data_postprocess(output[0], conf_threshold, iou_threshold, old_size, new_size, labels)
        return _yolo_postprocess(
            output=output[0],
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=labels
        )

    def clear(self):
        """