Commit dac2350d authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add detection similarity runnable code

parent 12a979e7
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -17,5 +17,6 @@ from .hand import detect_hands
from .head import detect_heads
from .nudenet import detect_with_nudenet
from .person import detect_person
from .similarity import calculate_iou, bboxes_similarity, detection_similarity
from .text import detect_text
from .visual import detection_visualize
+4 −0
Original line number Diff line number Diff line
from typing import Tuple

BBoxTyping = Tuple[float, float, float, float]
BBoxWithScoreAndLabel = Tuple[BBoxTyping, str, float]
+74 −0
Original line number Diff line number Diff line
from typing import List, Literal, Union

import numpy as np

from .base import BBoxTyping, BBoxWithScoreAndLabel


def calculate_iou(box1: BBoxTyping, box2: BBoxTyping):
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    iou = intersection / (area1 + area2 - intersection + 1e-6)
    return float(iou)


def bboxes_similarity(bboxes1: List[BBoxTyping], bboxes2: List[BBoxTyping],
                      mode: Literal['max', 'mean', 'raw'] = 'mean') -> Union[float, List[float]]:
    if len(bboxes1) != len(bboxes2):
        raise ValueError(f'Length of bboxes lists not match - {len(bboxes1)} vs {len(bboxes2)}.')

    n = len(bboxes1)
    iou_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            iou_matrix[i, j] = calculate_iou(bboxes1[i], bboxes2[j])

    # import here for faster launching speed
    from scipy.optimize import linear_sum_assignment
    row_ind, col_ind = linear_sum_assignment(-iou_matrix)
    print(iou_matrix)
    similarities = iou_matrix[row_ind, col_ind]
    if mode == 'max':
        return float(similarities.max())
    elif mode == 'mean':
        return float(similarities.mean())
    elif mode == 'raw':
        return similarities.tolist()
    else:
        raise ValueError(f'Unknown similarity mode for bboxes - {mode!r}.')


def detection_similarity(detect1: List[BBoxWithScoreAndLabel], detect2: List[BBoxWithScoreAndLabel],
                         mode: Literal['max', 'mean', 'raw'] = 'mean') -> Union[float, List[float]]:
    labels = sorted({*(l for _, l, _ in detect1), *(l for _, l, _ in detect2)})
    sims = []
    for current_label in labels:
        bboxes1 = [bbox for bbox, label, _ in detect1 if label == current_label]
        bboxes2 = [bbox for bbox, label, _ in detect2 if label == current_label]

        if len(bboxes1) != len(bboxes2):
            raise ValueError(f'Length of bboxes not match on label {current_label!r}'
                             f' - {len(bboxes1)} vs {len(bboxes2)}.')

        sims.extend(bboxes_similarity(
            bboxes1=bboxes1,
            bboxes2=bboxes2,
            mode='raw',
        ))

    sims = np.array(sims)
    if mode == 'max':
        return float(sims.max())
    elif mode == 'mean':
        return float(sims.mean())
    elif mode == 'raw':
        return sims.tolist()
    else:
        raise ValueError(f'Unknown similarity mode for bboxes - {mode!r}.')
+75 −0
Original line number Diff line number Diff line
import pytest

from imgutils.detect.similarity import calculate_iou, bboxes_similarity, detection_similarity


@pytest.fixture
def sample_bboxes():
    return [
        (0, 0, 10, 10),
        (5, 5, 15, 15),
        (20, 20, 30, 30),
    ]


@pytest.fixture
def sample_detections():
    return [
        ((0, 0, 10, 10), 'car', 0.9),
        ((5, 5, 15, 15), 'person', 0.8),
        ((20, 20, 30, 30), 'car', 0.7),
    ]


@pytest.mark.unittest
class TestBBoxFunctions:
    def test_calculate_iou(self):
        box1 = (0, 0, 10, 10)
        box2 = (5, 5, 15, 15)
        assert calculate_iou(box1, box2) == pytest.approx(25.0 / 175)

    def test_bboxes_similarity_max(self, sample_bboxes):
        result = bboxes_similarity(sample_bboxes, sample_bboxes, mode='max')
        assert isinstance(result, float)
        assert result == pytest.approx(1.0)

    def test_bboxes_similarity_mean(self, sample_bboxes):
        result = bboxes_similarity(sample_bboxes, sample_bboxes, mode='mean')
        assert isinstance(result, float)
        assert result == pytest.approx(1.0)

    def test_bboxes_similarity_raw(self, sample_bboxes):
        result = bboxes_similarity(sample_bboxes, sample_bboxes, mode='raw')
        assert isinstance(result, list)
        assert result == pytest.approx([1.0, 1.0, 1.0])

    def test_bboxes_similarity_invalid_mode(self, sample_bboxes):
        with pytest.raises(ValueError, match="Unknown similarity mode for bboxes - 'invalid'"):
            bboxes_similarity(sample_bboxes, sample_bboxes, mode='invalid')

    def test_bboxes_similarity_unequal_length(self, sample_bboxes):
        with pytest.raises(ValueError, match="Length of bboxes lists not match"):
            bboxes_similarity(sample_bboxes, sample_bboxes[:-1])

    def test_detection_similarity_max(self, sample_detections):
        result = detection_similarity(sample_detections, sample_detections, mode='max')
        assert isinstance(result, float)
        assert result == pytest.approx(1.0)

    def test_detection_similarity_mean(self, sample_detections):
        result = detection_similarity(sample_detections, sample_detections, mode='mean')
        assert isinstance(result, float)
        assert result == pytest.approx(1.0)

    def test_detection_similarity_raw(self, sample_detections):
        result = detection_similarity(sample_detections, sample_detections, mode='raw')
        assert isinstance(result, list)
        assert result == pytest.approx([1.0, 1.0, 1.0])

    def test_detection_similarity_invalid_mode(self, sample_detections):
        with pytest.raises(ValueError, match="Unknown similarity mode for bboxes - 'invalid'"):
            detection_similarity(sample_detections, sample_detections, mode='invalid')

    def test_detection_similarity_unequal_length(self, sample_detections):
        with pytest.raises(ValueError, match="Length of bboxes not match on label"):
            detection_similarity(sample_detections, sample_detections[:-1])