Unverified Commit 5f14d6af authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #110 from deepghs/dev/rtdetr

dev(narugo): add support for rtdetr detection model
parents 33771004 233b2426
Loading
Loading
Loading
Loading
+196 −37
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
@@ -173,7 +173,7 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
        - The preprocessed image
        - Original image dimensions (width, height)
        - New image dimensions (width, height)
    :rtype: tuple(Image.Image, tuple(int, int), tuple(int, int))
    :rtype: tuple(Image.Image, Tuple[int, int], Tuple[int, int])

    :Example:

@@ -194,7 +194,7 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
    return image, (old_width, old_height), (new_width, new_height)


def _xy_postprocess(x, y, old_size, new_size):
def _xy_postprocess(x, y, old_size: Tuple[float, float], new_size: Tuple[float, float]):
    """
    Convert coordinates from the preprocessed image size back to the original image size.

@@ -203,12 +203,12 @@ def _xy_postprocess(x, y, old_size, new_size):
    :param y: Y-coordinate in the preprocessed image.
    :type y: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: tuple(int, int)
    :type old_size: Tuple[float, float]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: tuple(int, int)
    :type new_size: Tuple[float, float]

    :return: Adjusted (x, y) coordinates for the original image size.
    :rtype: tuple(int, int)
    :rtype: Tuple[int, int]

    :Example:

@@ -223,36 +223,35 @@ def _xy_postprocess(x, y, old_size, new_size):
    return x, y


def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size, labels: List[str]):
def _end2end_postprocess(output, conf_threshold: float, iou_threshold: float,
                         old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output from the object detection model.
    Post-process the output of an end-to-end object detection model.

    This function applies confidence thresholding, non-maximum suppression, and
    converts the coordinates back to the original image size.
    This function filters detections based on confidence, applies non-maximum suppression,
    and transforms coordinates back to the original image size.

    :param output: Raw output from the object detection model.
    :param output: Raw output from the end-to-end object detection model.
    :type output: np.ndarray
    :param conf_threshold: Confidence threshold for filtering detections.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :param iou_threshold: IoU threshold for non-maximum suppression (not used in this function).
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: tuple(int, int)
    :type old_size: Tuple[float, float]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: tuple(int, int)
    :type new_size: Tuple[float, float]
    :param labels: List of class labels.
    :type labels: List[str]

    :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
    :rtype: List[tuple(tuple(int, int, int, int), str, float)]

    :Example:
    :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]

    >>> output = np.array([[10, 10, 20, 20, 0.9, 0.1]])
    >>> _data_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
    [((7, 7, 15, 15), 'cat', 0.9)]
    :raises AssertionError: If the output shape is not as expected.
    """
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
    assert output.shape[-1] == 6
    _ = iou_threshold  # actually the iou_threshold has not been supplied to end2end post-processing
    detections = []
    output = output[output[:, 4] > conf_threshold]
    selected_idx = _yolo_nms(output[:, :4], output[:, 4])
@@ -263,7 +262,38 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,

    return detections

    else:  # for nms-based models like yolov8

def _nms_postprocess(output, conf_threshold: float, iou_threshold: float,
                     old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the output of an NMS-based object detection model.

    This function applies confidence thresholding, non-maximum suppression,
    and transforms coordinates back to the original image size.

    :param output: Raw output from the NMS-based object detection model.
    :type output: np.ndarray
    :param conf_threshold: Confidence threshold for filtering detections.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[float, float]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: Tuple[float, float]
    :param labels: List of class labels.
    :type labels: List[str]

    :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
    :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]

    :raises AssertionError: If the output shape is not as expected.
    """
    assert output.shape[0] == 4 + len(labels)
    # the output should be like [4+cls, box_cnt]
    # cls means count of classes
    # box_cnt means count of bboxes
    max_scores = output[4:, :].max(axis=0)
    output = output[:, max_scores > conf_threshold].transpose(1, 0)
    boxes = output[:, :4]
@@ -287,6 +317,97 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,
    return detections


def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
                      old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output from the object detection model.

    This function applies confidence thresholding, non-maximum suppression, and
    converts the coordinates back to the original image size.

    :param output: Raw output from the object detection model.
    :type output: np.ndarray
    :param conf_threshold: Confidence threshold for filtering detections.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[float, float]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: Tuple[float, float]
    :param labels: List of class labels.
    :type labels: List[str]

    :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
    :rtype: List[tuple(tuple(int, int, int, int), str, float)]

    :Example:

    >>> output = np.array([[10, 10, 20, 20, 0.9, 0.1]])
    >>> _yolo_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
    [((7, 7, 15, 15), 'cat', 0.9)]
    """
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
        return _end2end_postprocess(
            output=output,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=labels,
        )
    else:  # for nms-based models like yolov8
        return _nms_postprocess(
            output=output,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=labels,
        )


def _rtdetr_postprocess(output, conf_threshold: float, iou_threshold: float,
                        old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the output from an RT-DETR (Real-Time DEtection TRansformer) model.

    This function handles the specific output format of RT-DETR models and applies
    the necessary post-processing steps.

    :param output: Raw output from the RT-DETR model.
    :type output: np.ndarray
    :param conf_threshold: Confidence threshold for filtering detections.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[int, int]
    :param new_size: Preprocessed image dimensions (width, height) (not used in this function).
    :type new_size: Tuple[int, int]
    :param labels: List of class labels.
    :type labels: List[str]

    :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
    :rtype: List[Tuple[Tuple[int, int, int, int], str, float]]

    :raises AssertionError: If the output shape is not as expected.
    """
    assert output.shape[-1] == 4 + len(labels)
    # the size rtdetr using is [0.0, 1.0]
    _ = new_size
    return _nms_postprocess(
        output=output.transpose(1, 0),
        conf_threshold=conf_threshold,
        iou_threshold=iou_threshold,
        old_size=old_size,
        new_size=(1.0, 1.0),
        labels=labels,
    )


def _safe_eval_names_str(names_str):
    """
    Safely evaluate the names string from model metadata.
@@ -351,6 +472,7 @@ class YOLOModel:
        self.repo_id = repo_id
        self._model_names = None
        self._models = {}
        self._model_types = {}
        self._hf_token = hf_token

    def _get_hf_token(self) -> Optional[str]:
@@ -422,6 +544,23 @@ class YOLOModel:

        return self._models[model_name]

    def _get_model_type(self, model_name: str):
        if model_name not in self._model_types:
            hf_fs = get_hf_fs(hf_token=self._get_hf_token())
            fs_path = hf_fs_path(
                repo_id=self.repo_id,
                repo_type='model',
                filename=f'{model_name}/model_type.json',
                revision='main',
            )
            if hf_fs.exists(fs_path):
                model_type = json.loads(hf_fs.read_text(fs_path))['model_type']
            else:
                model_type = 'yolo'
            self._model_types[model_name] = model_type

        return self._model_types[model_name]

    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
            -> List[Tuple[Tuple[int, int, int, int], str, float]]:
@@ -453,7 +592,27 @@ class YOLOModel:
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
        data = rgb_encode(new_image)[None, ...]
        output, = model.run(['output0'], {'images': data})
        return _data_postprocess(output[0], conf_threshold, iou_threshold, old_size, new_size, labels)
        model_type = self._get_model_type(model_name=model_name)
        if model_type == 'yolo':
            return _yolo_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=labels
            )
        elif model_type == 'rtdetr':
            return _rtdetr_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=labels
            )
        else:
            raise ValueError(f'Unknown object detection model type - {model_type!r}.')  # pragma: no cover

    def clear(self):
        """
+13 −0
Original line number Diff line number Diff line
@@ -51,3 +51,16 @@ class TestDetectHead:
            ((461, 247, 536, 330), 'head', 0.434),
        ])
        assert similarity >= 0.85

    @pytest.mark.parametrize(['model_name'], [
        ('head_detect_v1.6_l_rtdetr',),
    ])
    def test_detect_with_rtdetr(self, model_name: str):
        # ATTENTION: results of rtdetr models are really shitty and unstable
        #            so this expected result is 100% bullshit
        #            just make sure the rtdetr models can be properly inferred
        detections = detect_heads(get_testfile('genshin_post.jpg'), model_name=model_name)
        similarity = detection_similarity(detections, [
            ((780, 9, 1125, 208), 'head', 0.3077814280986786)
        ])
        assert similarity >= 0.85