Commit 3e2f04a2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update classify

parent deb215eb
Loading
Loading
Loading
Loading
+12 −8
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ __all__ = [
    'ClassifyModel',
    'classify_predict_score',
    'classify_predict',
    'classify_predict_fmt',
]


@@ -390,14 +391,8 @@ class ClassifyModel:
        for vname in vnames(fmt, str_only=True):
            matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname)
            if matching:
                if matching.group('topk'):
                    topk = int(matching.group('topk'))
                else:
                    topk = None
                if matching.group('label_group'):
                    group_label = matching.group('label_group')
                else:
                    group_label = 'default'
                topk = int(matching.group('topk')) if matching.group('topk') else None
                group_label = matching.group('label_group') if matching.group('label_group') else 'default'
                vname_to_spair[vname] = (topk, group_label)
                if (topk, group_label) not in d_scores:
                    d_scores[(topk, group_label)] = _labels_scores_to_topk(
@@ -621,3 +616,12 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str, label_gr
        model_name=model_name,
        label_group=label_group,
    )


def classify_predict_fmt(image: ImageTyping, repo_id: str, model_name: str, fmt='scores-top5',
                         hf_token: Optional[str] = None):
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_fmt(
        image=image,
        model_name=model_name,
        fmt=fmt,
    )
+68 −1
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import pytest
from PIL import Image

from imgutils.generic import classify_predict_score
from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.generic.classify import _open_models_for_repo_id, classify_predict_fmt
from test.testings import get_testfile


@@ -139,3 +139,70 @@ class TestGenericClassify:
            'screwdriver': 0.029927952215075493,
            'chain saw, chainsaw': 0.02070867270231247,
        }, abs=1e-3)

    def test_classify_predict_fmt(self):
        image = Image.open(get_testfile('png_640.png'))
        results = classify_predict_fmt(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
        )
        assert results == pytest.approx({
            'n02966687': 0.48493319749832153,
            'n03481172': 0.1228410005569458,
            'n04482393': 0.07170269638299942,
            'n04154565': 0.029927952215075493,
            'n03000684': 0.02070867270231247
        }, abs=1e-3)

    def test_classify_predict_fmt_complex(self):
        image = Image.open(get_testfile('png_640.png'))
        results = classify_predict_fmt(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            fmt={
                'scores-top10': 'scores-top10',
                'scores-top10-descriptions': 'scores-top10-descriptions',
                'scores-top5-definitions': 'scores-top5-definitions',
                'scores-top5-descriptions': 'scores-top5-descriptions',
            }
        )
        assert results['scores-top10'] == pytest.approx({
            'n02966687': 0.48493319749832153,
            'n03481172': 0.1228410005569458,
            'n04482393': 0.07170269638299942,
            'n04154565': 0.029927952215075493,
            'n03000684': 0.02070867270231247,
            'n03498962': 0.019339734688401222,
            'n03444034': 0.013918918557465076,
            'n03995372': 0.009074677713215351,
            'n03794056': 0.00785701535642147,
            'n03384352': 0.007194260135293007
        }, abs=1e-3)
        assert results['scores-top10-descriptions'] == pytest.approx({
            "carpenter's kit, tool kit": 0.48493319749832153,
            'hammer': 0.1228410005569458,
            'tricycle, trike, velocipede': 0.07170269638299942,
            'screwdriver': 0.029927952215075493,
            'chain saw, chainsaw': 0.02070867270231247,
            'hatchet': 0.019339734688401222,
            'go-kart': 0.013918918557465076,
            'power drill': 0.009074677713215351,
            'mousetrap': 0.00785701535642147,
            'forklift': 0.007194260135293007
        }, abs=1e-3)
        assert results['scores-top5-definitions'] == pytest.approx({
            "a set of carpenter's tools": 0.48493319749832153,
            'a hand tool with a heavy rigid head and a handle; used to deliver an impulsive force by striking': 0.1228410005569458,
            'a vehicle with three wheels that is moved by foot pedals': 0.07170269638299942,
            'a hand tool for driving screws; has a tip that fits into the head of a screw': 0.029927952215075493,
            'portable power saw; teeth linked to form an endless chain': 0.02070867270231247
        }, abs=1e-3)
        assert results['scores-top5-descriptions'] == pytest.approx({
            "carpenter's kit, tool kit": 0.48493319749832153,
            'hammer': 0.1228410005569458,
            'tricycle, trike, velocipede': 0.07170269638299942,
            'screwdriver': 0.029927952215075493,
            'chain saw, chainsaw': 0.02070867270231247
        }, abs=1e-3)