Commit 2e77452e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add support for rtdetr models

parent c894ec5f
Loading
Loading
Loading
Loading
+65 −17
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
import numpy as np
from PIL import Image
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_client
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
@@ -194,7 +194,7 @@ def _image_preprocess(image: Image.Image, max_infer_size: int = 1216, align: int
    return image, (old_width, old_height), (new_width, new_height)


def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):
def _xy_postprocess(x, y, old_size: Tuple[float, float], new_size: Tuple[float, float]):
    """
    Convert coordinates from the preprocessed image size back to the original image size.

@@ -203,9 +203,9 @@ def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):
    :param y: Y-coordinate in the preprocessed image.
    :type y: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[int, int]
    :type old_size: Tuple[float, float]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: Tuple[int, int]
    :type new_size: Tuple[float, float]

    :return: Adjusted (x, y) coordinates for the original image size.
    :rtype: Tuple[int, int]
@@ -224,7 +224,7 @@ def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):


def _end2end_postprocess(output, conf_threshold: float, iou_threshold: float,
                         old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
                         old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    assert output.shape[-1] == 6
    _ = iou_threshold  # actually the iou_threshold has not been supplied to end2end post-processing
@@ -240,8 +240,9 @@ def _end2end_postprocess(output, conf_threshold: float, iou_threshold: float,


def _nms_postprocess(output, conf_threshold: float, iou_threshold: float,
                     old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
                     old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    assert output.shape[0] == 4 + len(labels)
    # the output should be like [4+cls, box_cnt]
    # cls means count of classes
    # box_cnt means count of bboxes
@@ -269,7 +270,7 @@ def _nms_postprocess(output, conf_threshold: float, iou_threshold: float,


def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
                      old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
                      old_size: Tuple[float, float], new_size: Tuple[float, float], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output from the object detection model.
@@ -284,9 +285,9 @@ def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[int, int]
    :type old_size: Tuple[float, float]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: Tuple[int, int]
    :type new_size: Tuple[float, float]
    :param labels: List of class labels.
    :type labels: List[str]

@@ -319,6 +320,22 @@ def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
        )


def _rtdetr_postprocess(output, conf_threshold: float, iou_threshold: float,
                        old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    assert output.shape[-1] == 4 + len(labels)
    # the size rtdetr using is [0.0, 1.0]
    _ = new_size
    return _nms_postprocess(
        output=output.transpose(1, 0),
        conf_threshold=conf_threshold,
        iou_threshold=iou_threshold,
        old_size=old_size,
        new_size=(1.0, 1.0),
        labels=labels,
    )


def _safe_eval_names_str(names_str):
    """
    Safely evaluate the names string from model metadata.
@@ -383,6 +400,7 @@ class YOLOModel:
        self.repo_id = repo_id
        self._model_names = None
        self._models = {}
        self._model_types = {}
        self._hf_token = hf_token

    def _get_hf_token(self) -> Optional[str]:
@@ -454,6 +472,23 @@ class YOLOModel:

        return self._models[model_name]

    def _get_model_type(self, model_name: str):
        if model_name not in self._model_types:
            hf_fs = get_hf_fs(hf_token=self._get_hf_token())
            fs_path = hf_fs_path(
                repo_id=self.repo_id,
                repo_type='model',
                filename=f'{model_name}/model_type.json',
                revision='main',
            )
            if hf_fs.exists(fs_path):
                model_type = json.loads(hf_fs.read_text(fs_path))['model_type']
            else:
                model_type = 'yolo'
            self._model_types[model_name] = model_type

        return self._model_types[model_name]

    def predict(self, image: ImageTyping, model_name: str,
                conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
            -> List[Tuple[Tuple[int, int, int, int], str, float]]:
@@ -485,6 +520,8 @@ class YOLOModel:
        new_image, old_size, new_size = _image_preprocess(image, max_infer_size)
        data = rgb_encode(new_image)[None, ...]
        output, = model.run(['output0'], {'images': data})
        model_type = self._get_model_type(model_name=model_name)
        if model_type == 'yolo':
            return _yolo_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
@@ -493,6 +530,17 @@ class YOLOModel:
                new_size=new_size,
                labels=labels
            )
        elif model_type == 'rtdetr':
            return _rtdetr_postprocess(
                output=output[0],
                conf_threshold=conf_threshold,
                iou_threshold=iou_threshold,
                old_size=old_size,
                new_size=new_size,
                labels=labels
            )
        else:
            raise ValueError(f'Unknown object detection model type - {model_type!r}.')  # pragma: no cover

    def clear(self):
        """