Commit c13aa7bf authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest for new classify features

parent 49efe9b6
Loading
Loading
Loading
Loading
+22 −4
Original line number Diff line number Diff line
@@ -339,6 +339,8 @@ class ClassifyModel:
        if topk and topk < labels.shape[-1]:
            indices = np.argpartition(scores, -topk)[-topk:]
            indices = indices[np.argsort(-scores[indices], kind='mergesort')]
        else:
            indices = np.argsort(-scores, kind='mergesort')
        labels, scores = labels[indices], scores[indices]

        # noinspection PyTypeChecker
@@ -514,6 +516,7 @@ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> Cl


def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
                           label_group: str = 'default', topk: Optional[int] = 20,
                           hf_token: Optional[str] = None) -> Dict[str, float]:
    """
    Predict the scores for each class using the specified model and repository.
@@ -526,6 +529,10 @@ def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param label_group: Label group for the classification result.
    :type label_group: str
    :param topk: Top-K result. Default is 20, return all results when None ia assigned.
    :type topk: Optional[int]
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

@@ -535,10 +542,15 @@ def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
    :raises ValueError: If the model name or repository ID is invalid.
    :raises RuntimeError: If there's an error during prediction.
    """
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_score(image, model_name)
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_score(
        image=image,
        model_name=model_name,
        label_group=label_group,
        topk=topk,
    )


def classify_predict(image: ImageTyping, repo_id: str, model_name: str,
def classify_predict(image: ImageTyping, repo_id: str, model_name: str, label_group: str = 'default',
                     hf_token: Optional[str] = None) -> Tuple[str, float]:
    """
    Predict the class with the highest score using the specified model and repository.
@@ -551,6 +563,8 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str,
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param label_group: Label group for the classification result.
    :type label_group: str
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

@@ -560,4 +574,8 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str,
    :raises ValueError: If the model name or repository ID is invalid.
    :raises RuntimeError: If there's an error during prediction.
    """
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict(image, model_name)
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict(
        image=image,
        model_name=model_name,
        label_group=label_group,
    )
+0 −0

Empty file added.

+135 −0
Original line number Diff line number Diff line
import pytest
from PIL import Image

from imgutils.generic import classify_predict_score
from test.testings import get_testfile


@pytest.mark.unittest
class TestGenericClassify:
    def test_classify_predict_score(self):
        image = Image.open(get_testfile('png_640.png'))
        scores = classify_predict_score(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            # label_group='descriptions',
            # label_group='definitions',
            # topk=None,
        )
        assert scores == pytest.approx({
            'n02966687': 0.48493319749832153,
            'n03481172': 0.1228410005569458,
            'n04482393': 0.07170269638299942,
            'n04154565': 0.029927952215075493,
            'n03000684': 0.02070867270231247,
            'n03498962': 0.019339734688401222,
            'n03444034': 0.013918918557465076,
            'n03995372': 0.009074677713215351,
            'n03794056': 0.00785701535642147,
            'n03384352': 0.007194260135293007,
            'n02930766': 0.005858728662133217,
            'n02835271': 0.005415018182247877,
            'n04336792': 0.005253748968243599,
            'n04509417': 0.004338286351412535,
            'n03792782': 0.004189436789602041,
            'n03532672': 0.004000757820904255,
            'n03109150': 0.0034237923100590706,
            'n04517823': 0.0027278559282422066,
            'n03126707': 0.0026790976990014315,
            'n02879718': 0.0026228304486721754
        })

    def test_classify_predict_score_group1(self):
        image = Image.open(get_testfile('png_640.png'))
        scores = classify_predict_score(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            label_group='descriptions',
        )
        assert scores == pytest.approx({
            "carpenter's kit, tool kit": 0.48493319749832153,
            'hammer': 0.1228410005569458,
            'tricycle, trike, velocipede': 0.07170269638299942,
            'screwdriver': 0.029927952215075493,
            'chain saw, chainsaw': 0.02070867270231247,
            'hatchet': 0.019339734688401222,
            'go-kart': 0.013918918557465076,
            'power drill': 0.009074677713215351, 'mousetrap': 0.00785701535642147,
            'forklift': 0.007194260135293007,
            'cab, hack, taxi, taxicab': 0.005858728662133217,
            'bicycle-built-for-two, tandem bicycle, tandem': 0.005415018182247877,
            'stretcher': 0.005253748968243599,
            'unicycle, monocycle': 0.004338286351412535,
            'mountain bike, all-terrain bike, off-roader': 0.004189436789602041,
            'hook, claw': 0.004000757820904255,
            'corkscrew, bottle screw': 0.0034237923100590706,
            'vacuum, vacuum cleaner': 0.0027278559282422066,
            'crane': 0.0026790976990014315,
            'bow': 0.0026228304486721754
        })

    def test_classify_predict_score_group2(self):
        image = Image.open(get_testfile('png_640.png'))
        scores = classify_predict_score(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            label_group='definitions',
        )
        assert scores == pytest.approx({
            "a set of carpenter's tools": 0.48493319749832153,
            'a hand tool with a heavy rigid head and a handle; used to deliver an impulsive force by striking': 0.1228410005569458,
            'a vehicle with three wheels that is moved by foot pedals': 0.07170269638299942,
            'a hand tool for driving screws; has a tip that fits into the head of a screw': 0.029927952215075493,
            'portable power saw; teeth linked to form an endless chain': 0.02070867270231247,
            'a small ax with a short handle used with one hand (usually to chop wood)': 0.019339734688401222,
            'a small low motor vehicle with four wheels and an open framework; used for racing': 0.013918918557465076,
            'a power tool for drilling holes into hard materials': 0.009074677713215351,
            'a trap for catching mice': 0.00785701535642147,
            'a small industrial vehicle with a power operated forked platform in front that can be inserted under loads to lift and move them': 0.007194260135293007,
            'a car driven by a person whose job is to take passengers where they want to go in exchange for money': 0.005858728662133217,
            'a bicycle with two sets of pedals and two seats': 0.005415018182247877,
            'a litter for transporting people who are ill or wounded or dead; usually consists of a sheet of canvas stretched between two poles': 0.005253748968243599,
            'a vehicle with a single wheel that is driven by pedals': 0.004338286351412535,
            'a bicycle with a sturdy frame and fat tires; originally designed for riding in mountainous country': 0.004189436789602041,
            'a mechanical device that is curved or bent to suspend or hold or pull something': 0.004000757820904255,
            'a bottle opener that pulls corks': 0.0034237923100590706,
            'an electrical home appliance that cleans by suction': 0.0027278559282422066,
            'lifts and moves heavy objects; lifting tackle is suspended from a pivoted boom that rotates around a vertical axis': 0.0026790976990014315,
            'a weapon for shooting arrows, composed of a curved piece of resilient wood with a taut cord to propel the arrow': 0.0026228304486721754
        })

    def test_classify_predict_score_top5(self):
        image = Image.open(get_testfile('png_640.png'))
        scores = classify_predict_score(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            topk=5,
        )
        assert scores == pytest.approx({
            'n02966687': 0.48493319749832153,
            'n03481172': 0.1228410005569458,
            'n04482393': 0.07170269638299942,
            'n04154565': 0.029927952215075493,
            'n03000684': 0.02070867270231247,
        })

    def test_classify_predict_score_top5_group1(self):
        image = Image.open(get_testfile('png_640.png'))
        scores = classify_predict_score(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            label_group='descriptions',
            topk=5,
        )
        assert scores == pytest.approx({
            "carpenter's kit, tool kit": 0.48493319749832153,
            'hammer': 0.1228410005569458,
            'tricycle, trike, velocipede': 0.07170269638299942,
            'screwdriver': 0.029927952215075493,
            'chain saw, chainsaw': 0.02070867270231247,
        })