Commit 12a979e7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add iou check for end-to-end models

parent 7baaafbf
Loading
Loading
Loading
Loading
+8 −6
Original line number Diff line number Diff line
@@ -103,7 +103,7 @@ def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
    return y


def _yolo_nms(boxes, scores, thresh: float = 0.7) -> List[int]:
def _yolo_nms(boxes, scores, iou_threshold: float = 0.7) -> List[int]:
    """
    Perform Non-Maximum Suppression (NMS) on bounding boxes.

@@ -113,8 +113,8 @@ def _yolo_nms(boxes, scores, thresh: float = 0.7) -> List[int]:
    :type boxes: np.ndarray
    :param scores: Array of confidence scores for each bounding box.
    :type scores: np.ndarray
    :param thresh: IoU threshold for considering boxes as overlapping. Default is 0.7.
    :type thresh: float
    :param iou_threshold: IoU threshold for considering boxes as overlapping. Default is 0.7.
    :type iou_threshold: float

    :return: List of indices of the boxes to keep after NMS.
    :rtype: List[int]
@@ -149,7 +149,7 @@ def _yolo_nms(boxes, scores, thresh: float = 0.7) -> List[int]:
        inter = w * h
        iou = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(iou <= thresh)[0]
        inds = np.where(iou <= iou_threshold)[0]
        order = order[inds + 1]

    return keep
@@ -254,7 +254,9 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,
    """
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
        detections = []
        for x0, y0, x1, y1, score, cls in output[output[:, 4] > conf_threshold]:
        output = output[output[:, 4] > conf_threshold]
        selected_idx = _yolo_nms(output[:, :4], output[:, 4])
        for x0, y0, x1, y1, score, cls in output[selected_idx]:
            x0, y0 = _xy_postprocess(x0, y0, old_size, new_size)
            x1, y1 = _xy_postprocess(x1, y1, old_size, new_size)
            detections.append(((x0, y0, x1, y1), labels[int(cls.item())], float(score)))
@@ -272,7 +274,7 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,
            return []

        boxes = _yolo_xywh2xyxy(boxes)
        idx = _yolo_nms(boxes, filtered_max_scores, thresh=iou_threshold)
        idx = _yolo_nms(boxes, filtered_max_scores, iou_threshold=iou_threshold)
        boxes, scores = boxes[idx], scores[idx]

        detections = []