Commit 9cc75b68 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): partial save

parent 88cedae3
Loading
Loading
Loading
Loading
+33 −15
Original line number Diff line number Diff line
@@ -92,6 +92,19 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
    return data.astype(np.float32)


def _labels_scores_to_topk(labels: np.ndarray, scores: np.ndarray, topk: Optional[int] = 20):
    if topk and topk < labels.shape[-1]:
        indices = np.argpartition(scores, -topk)[-topk:]
        indices = indices[np.argsort(-scores[indices], kind='mergesort')]
    else:
        indices = np.argsort(-scores, kind='mergesort')
    labels, scores = labels[indices], scores[indices]

    # noinspection PyTypeChecker
    values = dict(zip(labels.tolist(), scores.tolist()))
    return values


ImagePreprocessFunc = Callable[[Image.Image], Image.Image]


@@ -286,7 +299,7 @@ class ClassifyModel:
        :type model_name: str

        :return: Raw model output
        :rtype: np.ndarray
        :rtype: Dict[str, np.ndarray]

        :raises RuntimeError: If model input shape is incompatible
        """
@@ -308,8 +321,10 @@ class ClassifyModel:
                input_ = _img_encode(image, size=(width, height))[None, ...]
            else:
                input_ = _img_encode(image)[None, ...]
        output, = self._open_model(model_name).run(['output'], {'input': input_})
        return output
        onnx_model = self._open_model(model_name)
        output_names = [output.name for output in onnx_model.get_outputs()]
        output_values = self._open_model(model_name).run(output_names, {'input': input_})
        return {name: value for name, value in zip(output_names, output_values)}

    def predict_score(self, image: ImageTyping, model_name: str,
                      label_group: str = 'default', topk: Optional[int] = 20) -> Dict[str, float]:
@@ -333,19 +348,14 @@ class ClassifyModel:
        :raises ValueError: If the model name is invalid.
        :raises RuntimeError: If there's an error during prediction.
        """
        output = self._raw_predict(image, model_name)
        output = self._raw_predict(image, model_name)['output']
        labels = self._open_label(model_name)[label_group]
        scores = output[0]
        if topk and topk < labels.shape[-1]:
            indices = np.argpartition(scores, -topk)[-topk:]
            indices = indices[np.argsort(-scores[indices], kind='mergesort')]
        else:
            indices = np.argsort(-scores, kind='mergesort')
        labels, scores = labels[indices], scores[indices]

        # noinspection PyTypeChecker
        values = dict(zip(labels.tolist(), scores.tolist()))
        return values
        return _labels_scores_to_topk(
            labels=labels,
            scores=scores,
            topk=topk,
        )

    def predict(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Tuple[str, float]:
        """
@@ -366,10 +376,17 @@ class ClassifyModel:
        :raises ValueError: If the model name is invalid.
        :raises RuntimeError: If there's an error during prediction.
        """
        output = self._raw_predict(image, model_name)[0]
        output = self._raw_predict(image, model_name)['output'][0]
        max_id = np.argmax(output)
        return self._open_label(model_name)[label_group][max_id], output[max_id].item()

    def predict_fmt(self, image: ImageTyping, model_name: str,
                      label_group: str = 'default', topk: Optional[int] = 20):
        d_data = {name: value[0] for name, value in self._raw_predict(image, model_name).items()}
        scores = d_data['output']



    def clear(self):
        """
        Clear the cached models and labels.
@@ -378,6 +395,7 @@ class ClassifyModel:
        """
        self._models.clear()
        self._labels.clear()
        self._preprocesses.clear()

    def make_ui(self, default_model_name: Optional[str] = None):
        """
+27 −0
Original line number Diff line number Diff line
from typing import List

__all__ = [
    'vreplace',
    'vnames',
]


@@ -24,3 +27,27 @@ def vreplace(v, mapping):
            return v
        else:
            return mapping.get(v, v)


def _v_iternames(v):
    if isinstance(v, (list, tuple)):
        for item in v:
            yield from _v_iternames(item)
    elif isinstance(v, dict):
        for _, item in v.items():
            yield from _v_iternames(item)
    else:
        try:
            _ = hash(v)
        except TypeError:  # pragma: no cover
            pass
        else:
            yield v


def vnames(v, str_only: bool = True) -> List[str]:
    name_set = set()
    for name in _v_iternames(v):
        if not str_only or isinstance(name, str):
            name_set.add(name)
    return list(name_set)
+9 −0
Original line number Diff line number Diff line
@@ -2,9 +2,18 @@ import pytest
from PIL import Image

from imgutils.generic import classify_predict_score
from imgutils.generic.classify import _open_models_for_repo_id
from test.testings import get_testfile


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run():
    try:
        yield
    finally:
        _open_models_for_repo_id('deepghs/timms_mobilenet').clear()


@pytest.mark.unittest
class TestGenericClassify:
    def test_classify_predict_score(self):
+74 −0
Original line number Diff line number Diff line
import pytest

from imgutils.utils import vnames, vreplace


@pytest.fixture
def sample_list():
    return [1, 2, 3, "a", "b"]


@pytest.fixture
def sample_dict():
    return {"x": 1, "y": "a", "z": [2, "b", {"w": 3}]}


@pytest.fixture
def sample_mapping():
    return {1: "one", "a": "A", 3: "three"}


@pytest.mark.unittest
class TestVReplaceFunctions:
    def test_vreplace_list(self, sample_list, sample_mapping):
        result = vreplace(sample_list, sample_mapping)
        assert result == ["one", 2, "three", "A", "b"]
        assert isinstance(result, list)

    def test_vreplace_tuple(self, sample_mapping):
        input_tuple = (1, "a", 3)
        result = vreplace(input_tuple, sample_mapping)
        assert result == ("one", "A", "three")
        assert isinstance(result, tuple)

    def test_vreplace_dict(self, sample_dict, sample_mapping):
        result = vreplace(sample_dict, sample_mapping)
        assert result == {"x": "one", "y": "A", "z": [2, "b", {"w": "three"}]}
        assert isinstance(result, dict)

    def test_vreplace_scalar(self, sample_mapping):
        assert vreplace(1, sample_mapping) == "one"
        assert vreplace(4, sample_mapping) == 4  # unmapped value

    def test_vreplace_empty_mapping(self, sample_list):
        result = vreplace(sample_list, {})
        assert result == sample_list

    def test_vnames_list(self, sample_list):
        result = vnames(sample_list)
        assert set(result) == {"a", "b"}

        result_all = vnames(sample_list, str_only=False)
        assert set(result_all) == {1, 2, 3, "a", "b"}

    def test_vnames_dict(self, sample_dict):
        result = vnames(sample_dict)
        assert set(result) == {"a", "b"}

        result_all = vnames(sample_dict, str_only=False)
        assert set(result_all) == {1, 2, 3, "a", "b"}

    def test_vnames_empty(self):
        assert set(vnames([])) == set()
        assert set(vnames({})) == set()
        assert set(vnames([], str_only=False)) == set()

    def test_vnames_nested(self):
        nested = [1, ["a", 2, ["b", 3]], {"x": "c"}]
        assert set(vnames(nested)) == {"a", "b", "c"}
        assert set(vnames(nested, str_only=False)) == {1, 2, 3, "a", "b", "c"}

    def test_vnames_mixed_types(self):
        mixed = [1.5, True, None, "test"]
        assert set(vnames(mixed)) == {"test"}
        assert set(vnames(mixed, str_only=False)) == {1.5, True, None, "test"}