Commit 7baaafbf authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add support for yolov10

parent 8c8c208a
Loading
Loading
Loading
Loading
+31 −21
Original line number Diff line number Diff line
@@ -252,6 +252,16 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,
    >>> _data_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
    [((7, 7, 15, 15), 'cat', 0.9)]
    """
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
        detections = []
        for x0, y0, x1, y1, score, cls in output[output[:, 4] > conf_threshold]:
            x0, y0 = _xy_postprocess(x0, y0, old_size, new_size)
            x1, y1 = _xy_postprocess(x1, y1, old_size, new_size)
            detections.append(((x0, y0, x1, y1), labels[int(cls.item())], float(score)))

        return detections

    else:  # for nms-based models like yolov8
        max_scores = output[4:, :].max(axis=0)
        output = output[:, max_scores > conf_threshold].transpose(1, 0)
        boxes = output[:, :4]