Commit c894ec5f authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): extract the post process formats

parent 48de7f54
Loading
Loading
Loading
Loading
+63 −32
Original line number Diff line number Diff line
@@ -223,37 +223,11 @@ def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):
    return x, y


def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
                      old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]):
    """
    Post-process the raw output from the object detection model.

    This function applies confidence thresholding, non-maximum suppression, and
    converts the coordinates back to the original image size.

    :param output: Raw output from the object detection model.
    :type output: np.ndarray
    :param conf_threshold: Confidence threshold for filtering detections.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[int, int]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: Tuple[int, int]
    :param labels: List of class labels.
    :type labels: List[str]

    :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
    :rtype: List[tuple(tuple(int, int, int, int), str, float)]

    :Example:

    >>> output = np.array([[10, 10, 20, 20, 0.9, 0.1]])
    >>> _yolo_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
    [((7, 7, 15, 15), 'cat', 0.9)]
    """
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
def _end2end_postprocess(output, conf_threshold: float, iou_threshold: float,
                         old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    assert output.shape[-1] == 6
    _ = iou_threshold  # actually the iou_threshold has not been supplied to end2end post-processing
    detections = []
    output = output[output[:, 4] > conf_threshold]
    selected_idx = _yolo_nms(output[:, :4], output[:, 4])
@@ -264,7 +238,13 @@ def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,

    return detections

    else:  # for nms-based models like yolov8

def _nms_postprocess(output, conf_threshold: float, iou_threshold: float,
                     old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    # the output should be like [4+cls, box_cnt]
    # cls means count of classes
    # box_cnt means count of bboxes
    max_scores = output[4:, :].max(axis=0)
    output = output[:, max_scores > conf_threshold].transpose(1, 0)
    boxes = output[:, :4]
@@ -288,6 +268,57 @@ def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
    return detections


def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
                      old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output from the object detection model.

    This function applies confidence thresholding, non-maximum suppression, and
    converts the coordinates back to the original image size.

    :param output: Raw output from the object detection model.
    :type output: np.ndarray
    :param conf_threshold: Confidence threshold for filtering detections.
    :type conf_threshold: float
    :param iou_threshold: IoU threshold for non-maximum suppression.
    :type iou_threshold: float
    :param old_size: Original image dimensions (width, height).
    :type old_size: Tuple[int, int]
    :param new_size: Preprocessed image dimensions (width, height).
    :type new_size: Tuple[int, int]
    :param labels: List of class labels.
    :type labels: List[str]

    :return: List of detections, each in the format ((x0, y0, x1, y1), label, confidence).
    :rtype: List[tuple(tuple(int, int, int, int), str, float)]

    :Example:

    >>> output = np.array([[10, 10, 20, 20, 0.9, 0.1]])
    >>> _yolo_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
    [((7, 7, 15, 15), 'cat', 0.9)]
    """
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
        return _end2end_postprocess(
            output=output,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=labels,
        )
    else:  # for nms-based models like yolov8
        return _nms_postprocess(
            output=output,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=labels,
        )


def _safe_eval_names_str(names_str):
    """
    Safely evaluate the names string from model metadata.