Unverified Commit 77e9f90a authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #133 from deepghs/dev/cls

dev(narugo): add formatted classifier prediction model
parents 88cedae3 e7f48bdb
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ ClassifyModel
-----------------------------------------

.. autoclass:: ClassifyModel
    :members: __init__, predict_score, predict, clear, make_ui, launch_demo
    :members: __init__, predict_score, predict, predict_fmt, clear, make_ui, launch_demo



@@ -29,3 +29,10 @@ classify_predict



classify_predict_fmt
-----------------------------------------

.. autofunction:: classify_predict_fmt


+118 −16
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ It also handles token-based authentication for accessing private Hugging Face re

import json
import os
import re
from threading import Lock
from typing import Tuple, Optional, List, Dict, Callable

@@ -25,7 +26,7 @@ from huggingface_hub.errors import EntryNotFoundError

from ..data import rgb_encode, ImageTyping, load_image
from ..preprocess import create_pillow_transforms
from ..utils import open_onnx_model, ts_lru_cache
from ..utils import open_onnx_model, ts_lru_cache, vnames, vreplace

try:
    import gradio as gr
@@ -36,6 +37,7 @@ __all__ = [
    'ClassifyModel',
    'classify_predict_score',
    'classify_predict',
    'classify_predict_fmt',
]


@@ -92,6 +94,19 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
    return data.astype(np.float32)


def _labels_scores_to_topk(labels: np.ndarray, scores: np.ndarray, topk: Optional[int] = 20):
    """
    Pair labels with their scores, ordered from the highest score to the lowest.

    :param labels: Array of class labels.
    :type labels: np.ndarray
    :param scores: Array of scores, indexed the same way as ``labels``
        (assumed to be one-dimensional, i.e. the scores of a single sample).
    :type scores: np.ndarray
    :param topk: Number of best entries to keep. A falsy value (``None`` or ``0``),
        or a value not smaller than the number of labels, keeps every entry.
    :type topk: Optional[int]

    :return: Mapping from label to score, inserted in descending score order.
    :rtype: dict
    """
    if not topk or topk >= labels.shape[-1]:
        # Nothing to cut off: a full descending sort is all that is needed.
        order = np.argsort(-scores, kind='mergesort')
    else:
        # Select the top-k candidates first (unordered), then sort only those.
        candidates = np.argpartition(scores, -topk)[-topk:]
        order = candidates[np.argsort(-scores[candidates], kind='mergesort')]

    picked_labels = labels[order]
    picked_scores = scores[order]
    # noinspection PyTypeChecker
    return dict(zip(picked_labels.tolist(), picked_scores.tolist()))


ImagePreprocessFunc = Callable[[Image.Image], Image.Image]


@@ -286,7 +301,7 @@ class ClassifyModel:
        :type model_name: str

        :return: Raw model output
        :rtype: np.ndarray
        :rtype: Dict[str, np.ndarray]

        :raises RuntimeError: If model input shape is incompatible
        """
@@ -308,8 +323,10 @@ class ClassifyModel:
                input_ = _img_encode(image, size=(width, height))[None, ...]
            else:
                input_ = _img_encode(image)[None, ...]
        output, = self._open_model(model_name).run(['output'], {'input': input_})
        return output
        onnx_model = self._open_model(model_name)
        output_names = [output.name for output in onnx_model.get_outputs()]
        output_values = self._open_model(model_name).run(output_names, {'input': input_})
        return {name: value for name, value in zip(output_names, output_values)}

    def predict_score(self, image: ImageTyping, model_name: str,
                      label_group: str = 'default', topk: Optional[int] = 20) -> Dict[str, float]:
@@ -333,19 +350,14 @@ class ClassifyModel:
        :raises ValueError: If the model name is invalid.
        :raises RuntimeError: If there's an error during prediction.
        """
        output = self._raw_predict(image, model_name)
        output = self._raw_predict(image, model_name)['output']
        labels = self._open_label(model_name)[label_group]
        scores = output[0]
        if topk and topk < labels.shape[-1]:
            indices = np.argpartition(scores, -topk)[-topk:]
            indices = indices[np.argsort(-scores[indices], kind='mergesort')]
        else:
            indices = np.argsort(-scores, kind='mergesort')
        labels, scores = labels[indices], scores[indices]

        # noinspection PyTypeChecker
        values = dict(zip(labels.tolist(), scores.tolist()))
        return values
        return _labels_scores_to_topk(
            labels=labels,
            scores=scores,
            topk=topk,
        )

    def predict(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Tuple[str, float]:
        """
@@ -366,10 +378,59 @@ class ClassifyModel:
        :raises ValueError: If the model name is invalid.
        :raises RuntimeError: If there's an error during prediction.
        """
        output = self._raw_predict(image, model_name)[0]
        output = self._raw_predict(image, model_name)['output'][0]
        max_id = np.argmax(output)
        return self._open_label(model_name)[label_group][max_id], output[max_id].item()

    def predict_fmt(self, image: ImageTyping, model_name: str, fmt='scores-top5'):
        """
        Run the specified model on an image and return the result in the layout described by ``fmt``.

        ``fmt`` is either a single specification string or a nested structure of lists, tuples
        and dicts whose leaf values are specification strings. Every recognised leaf is replaced
        by the corresponding prediction data.

        :param image: The input image to classify.
        :type image: ImageTyping
        :param model_name: The name of the model to use for prediction.
        :type model_name: str
        :param fmt: Format specification. Default is ``scores-top5``.

        :return: Prediction result formatted with parameter ``fmt``.

        :raises ValueError: If the model name is invalid.
        :raises RuntimeError: If there's an error during prediction.

        .. note::
            The following specifications are supported in parameter ``fmt``:

            - ``output``, raw prediction result, in np.ndarray format.
            - ``logits``, (not available in some models) logits result, in np.ndarray format.
            - ``embedding``, (not available in some models) embeddings result, in np.ndarray format.
            - ``scores``, prediction scores of all classes in dict format.
            - ``scores-topK``, prediction scores of top-K classes in dict format, e.g. ``scores-top10`` means top 10 scores.
            - ``scores-<label_group>``, prediction scores of all classes with label group ``<label_group>``, e.g. ``scores-descriptions`` means all scores with ``descriptions`` label group.
            - ``scores-topK-<label_group>``, prediction scores of top-K classes with label group ``<label_group>``.
        """
        # Drop the leading batch axis of every model output (a single image is predicted).
        raw_outputs = self._raw_predict(image, model_name)
        d_data = {key: batch[0] for key, batch in raw_outputs.items()}
        scores = d_data['output']
        d_labels = self._open_label(model_name)

        score_cache = {}  # (topk, label_group) -> label/score dict, computed once per pair
        resolved = {}  # specification string -> label/score dict
        for vname in vnames(fmt, str_only=True):
            matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname)
            if not matching:
                # Not a ``scores...`` specification; raw outputs are resolved through d_data.
                continue

            topk_text = matching.group('topk')
            topk = int(topk_text) if topk_text else None
            label_group = matching.group('label_group') or 'default'

            cache_key = (topk, label_group)
            if cache_key not in score_cache:
                score_cache[cache_key] = _labels_scores_to_topk(
                    labels=d_labels[label_group],
                    scores=scores,
                    topk=topk,
                )
            resolved[vname] = score_cache[cache_key]

        return vreplace(fmt, mapping={**d_data, **resolved})

    def clear(self):
        """
        Clear the cached models and labels.
@@ -378,6 +439,7 @@ class ClassifyModel:
        """
        self._models.clear()
        self._labels.clear()
        self._preprocesses.clear()

    def make_ui(self, default_model_name: Optional[str] = None):
        """
@@ -579,3 +641,43 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str, label_gr
        model_name=model_name,
        label_group=label_group,
    )


def classify_predict_fmt(image: ImageTyping, repo_id: str, model_name: str, fmt='scores-top5',
                         hf_token: Optional[str] = None):
    """
    Predict with the specified model and return the result in the layout described by ``fmt``.

    This function is a convenience wrapper around ClassifyModel's ``predict_fmt`` method:
    it looks up the model collection for ``repo_id`` and delegates the call to it.

    :param image: The input image to classify.
    :type image: ImageTyping
    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param fmt: Format specification. Default is ``scores-top5``.
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :return: Prediction result formatted with parameter ``fmt``.

    :raises ValueError: If the model name is invalid.
    :raises RuntimeError: If there's an error during prediction.

    .. note::
        The following specifications are supported in parameter ``fmt``:

        - ``output``, raw prediction result, in np.ndarray format.
        - ``logits``, (not available in some models) logits result, in np.ndarray format.
        - ``embedding``, (not available in some models) embeddings result, in np.ndarray format.
        - ``scores``, prediction scores of all classes in dict format.
        - ``scores-topK``, prediction scores of top-K classes in dict format, e.g. ``scores-top10`` means top 10 scores.
        - ``scores-<label_group>``, prediction scores of all classes with label group ``<label_group>``, e.g. ``scores-descriptions`` means all scores with ``descriptions`` label group.
        - ``scores-topK-<label_group>``, prediction scores of top-K classes with label group ``<label_group>``.
    """
    model = _open_models_for_repo_id(repo_id, hf_token=hf_token)
    return model.predict_fmt(image=image, model_name=model_name, fmt=fmt)
+82 −5
Original line number Diff line number Diff line
"""
This module provides utilities for manipulating nested data structures, particularly focusing on
value replacement and name extraction from complex nested structures like lists, tuples, and dictionaries.

The module offers functionality to recursively traverse nested data structures and either replace values
based on a mapping or extract unique names/values from the structure.
"""

from typing import List

__all__ = [
    'vreplace',
    'vnames',
]


def vreplace(v, mapping):
    """
    Replaces values in a data structure using a mapping dictionary.
    :param v: The input data structure.
    Recursively replaces values in a nested data structure using a mapping dictionary.

    This function traverses through nested lists, tuples, and dictionaries, replacing values
    according to the provided mapping. If a value exists as a key in the mapping dictionary,
    it will be replaced with its corresponding value.

    :param v: The input data structure to process (can be a nested structure of lists, tuples, dicts)
    :type v: Any
    :param mapping: A dictionary mapping values to replacement values.
    :type mapping: Dict
    :return: The modified data structure.
    :param mapping: A dictionary defining value replacements
    :type mapping: dict

    :return: A new data structure with values replaced according to the mapping
    :rtype: Any

    :example:
        >>> data = {'a': [1, 2, 3], 'b': {'x': 1}}
        >>> mapping = {1: 'one', 2: 'two'}
        >>> vreplace(data, mapping)
        {'a': ['one', 'two', 3], 'b': {'x': 'one'}}
    """
    if isinstance(v, (list, tuple)):
        return type(v)([vreplace(vitem, mapping) for vitem in v])
@@ -24,3 +47,57 @@ def vreplace(v, mapping):
            return v
        else:
            return mapping.get(v, v)


def _v_iternames(v):
    """
    Internal helper function that yields all hashable leaf values from a nested data structure.

    Lists and tuples are traversed item by item; for dictionaries only the values are
    traversed (keys are ignored). Any other object is treated as a leaf and is yielded
    if it is hashable. Duplicates are not removed here.

    :param v: The input data structure to traverse
    :type v: Any

    :yield: Hashable values found in the data structure, in traversal order
    :rtype: Generator
    """
    if isinstance(v, (list, tuple)):
        for item in v:
            yield from _v_iternames(item)
    elif isinstance(v, dict):
        for item in v.values():
            yield from _v_iternames(item)
    else:
        try:
            hash(v)
        except TypeError:  # pragma: no cover
            # Unhashable leaves cannot be used as mapping keys, so they are skipped.
            pass
        else:
            yield v


def vnames(v, str_only: bool = True) -> List[str]:
    """
    Extracts unique values/names from a nested data structure.

    This function traverses through the input data structure and collects all unique
    hashable leaf values. When str_only is True, it only collects string values.
    Dictionary keys are not collected, only dictionary values are.

    The result keeps the order in which each value is first encountered, so it is
    deterministic across runs.

    :param v: The input data structure to process
    :type v: Any
    :param str_only: If True, only string values are collected
    :type str_only: bool

    :return: A list of unique values found in the data structure, in first-seen order
    :rtype: List[str]

    :example:
        >>> data = {'a': ['x', 'y', 1], 'b': {'z': 'x'}}
        >>> vnames(data)
        ['x', 'y']
        >>> vnames(data, str_only=False)
        ['x', 'y', 1]
    """
    # A dict is used as an insertion-ordered set: unlike set(), it preserves
    # first-seen order and does not depend on string hash randomization.
    unique_names = {}
    for name in _v_iternames(v):
        if not str_only or isinstance(name, str):
            unique_names.setdefault(name, None)
    return list(unique_names)
+86 −0
Original line number Diff line number Diff line
import numpy as np
import pytest
from PIL import Image

from imgutils.generic import classify_predict_score
from imgutils.generic.classify import _open_models_for_repo_id, classify_predict_fmt
from test.testings import get_testfile


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run():
    """
    Module-wide automatic fixture that clears the cached ``deepghs/timms_mobilenet``
    models once all tests of this module have run, even if a test failed.
    """
    try:
        yield
    finally:
        models = _open_models_for_repo_id('deepghs/timms_mobilenet')
        models.clear()

@pytest.mark.unittest
class TestGenericClassify:
    def test_classify_predict_score(self):
@@ -130,3 +140,79 @@ class TestGenericClassify:
            'screwdriver': 0.029927952215075493,
            'chain saw, chainsaw': 0.02070867270231247,
        }, abs=1e-3)

    def test_classify_predict_fmt(self):
        """With the default ``fmt`` the result is the label-to-score dict of the five best classes."""
        expected_top5 = {
            'n02966687': 0.48493319749832153,
            'n03481172': 0.1228410005569458,
            'n04482393': 0.07170269638299942,
            'n04154565': 0.029927952215075493,
            'n03000684': 0.02070867270231247
        }

        image = Image.open(get_testfile('png_640.png'))
        results = classify_predict_fmt(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
        )
        assert results == pytest.approx(expected_top5, abs=1e-3)

    def test_classify_predict_fmt_complex(self):
        """A dict-shaped ``fmt`` returns every requested score view plus the embedding vector."""
        requested = [
            'scores-top10',
            'scores-top10-descriptions',
            'scores-top5-definitions',
            'scores-top5-descriptions',
            'embedding',
        ]
        image = Image.open(get_testfile('png_640.png'))
        results = classify_predict_fmt(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            fmt={name: name for name in requested},
        )

        # Expected label -> score mapping for each requested score view.
        expected_scores = {
            'scores-top10': {
                'n02966687': 0.48493319749832153,
                'n03481172': 0.1228410005569458,
                'n04482393': 0.07170269638299942,
                'n04154565': 0.029927952215075493,
                'n03000684': 0.02070867270231247,
                'n03498962': 0.019339734688401222,
                'n03444034': 0.013918918557465076,
                'n03995372': 0.009074677713215351,
                'n03794056': 0.00785701535642147,
                'n03384352': 0.007194260135293007
            },
            'scores-top10-descriptions': {
                "carpenter's kit, tool kit": 0.48493319749832153,
                'hammer': 0.1228410005569458,
                'tricycle, trike, velocipede': 0.07170269638299942,
                'screwdriver': 0.029927952215075493,
                'chain saw, chainsaw': 0.02070867270231247,
                'hatchet': 0.019339734688401222,
                'go-kart': 0.013918918557465076,
                'power drill': 0.009074677713215351,
                'mousetrap': 0.00785701535642147,
                'forklift': 0.007194260135293007
            },
            'scores-top5-definitions': {
                "a set of carpenter's tools": 0.48493319749832153,
                'a hand tool with a heavy rigid head and a handle; used to deliver an impulsive force by striking': 0.1228410005569458,
                'a vehicle with three wheels that is moved by foot pedals': 0.07170269638299942,
                'a hand tool for driving screws; has a tip that fits into the head of a screw': 0.029927952215075493,
                'portable power saw; teeth linked to form an endless chain': 0.02070867270231247
            },
            'scores-top5-descriptions': {
                "carpenter's kit, tool kit": 0.48493319749832153,
                'hammer': 0.1228410005569458,
                'tricycle, trike, velocipede': 0.07170269638299942,
                'screwdriver': 0.029927952215075493,
                'chain saw, chainsaw': 0.02070867270231247
            },
        }
        for view_name, expected in expected_scores.items():
            assert results[view_name] == pytest.approx(expected, abs=1e-3)

        # The reference file can be regenerated with:
        # np.save(get_testfile('png_640_emb.npy'), results['embedding'])
        actual_embedding = results['embedding']
        assert actual_embedding.shape == (1280,)
        expected_embedding = np.load(get_testfile('png_640_emb.npy'))

        # Compare direction (cosine similarity of the normalised vectors) and magnitude separately.
        emb_1 = actual_embedding / np.linalg.norm(actual_embedding, axis=-1, keepdims=True)
        emb_2 = expected_embedding / np.linalg.norm(expected_embedding, axis=-1, keepdims=True)
        emb_sims = (emb_1 * emb_2).sum()
        assert emb_sims >= 0.99, 'Direction not match with expected embedding.'
        assert np.linalg.norm(results['embedding']) == pytest.approx(np.linalg.norm(expected_embedding))
+5.13 KiB

File added.

No diff preview for this file type.

Loading