Loading imgutils/generic/classify.py +12 −8 Original line number Diff line number Diff line Loading @@ -37,6 +37,7 @@ __all__ = [ 'ClassifyModel', 'classify_predict_score', 'classify_predict', 'classify_predict_fmt', ] Loading Loading @@ -390,14 +391,8 @@ class ClassifyModel: for vname in vnames(fmt, str_only=True): matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname) if matching: if matching.group('topk'): topk = int(matching.group('topk')) else: topk = None if matching.group('label_group'): group_label = matching.group('label_group') else: group_label = 'default' topk = int(matching.group('topk')) if matching.group('topk') else None group_label = matching.group('label_group') if matching.group('label_group') else 'default' vname_to_spair[vname] = (topk, group_label) if (topk, group_label) not in d_scores: d_scores[(topk, group_label)] = _labels_scores_to_topk( Loading Loading @@ -621,3 +616,12 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str, label_gr model_name=model_name, label_group=label_group, ) def classify_predict_fmt(image: ImageTyping, repo_id: str, model_name: str, fmt='scores-top5', hf_token: Optional[str] = None): return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_fmt( image=image, model_name=model_name, fmt=fmt, ) test/generic/test_classify.py +68 −1 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ import pytest from PIL import Image from imgutils.generic import classify_predict_score from imgutils.generic.classify import _open_models_for_repo_id from imgutils.generic.classify import _open_models_for_repo_id, classify_predict_fmt from test.testings import get_testfile Loading Loading @@ -139,3 +139,70 @@ class TestGenericClassify: 'screwdriver': 0.029927952215075493, 'chain saw, chainsaw': 0.02070867270231247, }, abs=1e-3) def test_classify_predict_fmt(self): image = Image.open(get_testfile('png_640.png')) results = classify_predict_fmt( image, repo_id='deepghs/timms_mobilenet', model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k', ) assert results == pytest.approx({ 'n02966687': 0.48493319749832153, 'n03481172': 0.1228410005569458, 'n04482393': 0.07170269638299942, 'n04154565': 0.029927952215075493, 'n03000684': 0.02070867270231247 }, abs=1e-3) def test_classify_predict_fmt_complex(self): image = Image.open(get_testfile('png_640.png')) results = classify_predict_fmt( image, repo_id='deepghs/timms_mobilenet', model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k', fmt={ 'scores-top10': 'scores-top10', 'scores-top10-descriptions': 'scores-top10-descriptions', 'scores-top5-definitions': 'scores-top5-definitions', 'scores-top5-descriptions': 'scores-top5-descriptions', } ) assert results['scores-top10'] == pytest.approx({ 'n02966687': 0.48493319749832153, 'n03481172': 0.1228410005569458, 'n04482393': 0.07170269638299942, 'n04154565': 0.029927952215075493, 'n03000684': 0.02070867270231247, 'n03498962': 0.019339734688401222, 'n03444034': 0.013918918557465076, 'n03995372': 0.009074677713215351, 'n03794056': 0.00785701535642147, 'n03384352': 0.007194260135293007 }, abs=1e-3) assert results['scores-top10-descriptions'] == pytest.approx({ "carpenter's kit, tool kit": 0.48493319749832153, 'hammer': 0.1228410005569458, 'tricycle, trike, velocipede': 0.07170269638299942, 'screwdriver': 0.029927952215075493, 'chain saw, chainsaw': 0.02070867270231247, 'hatchet': 0.019339734688401222, 'go-kart': 0.013918918557465076, 'power drill': 0.009074677713215351, 'mousetrap': 0.00785701535642147, 'forklift': 0.007194260135293007 }, abs=1e-3) assert results['scores-top5-definitions'] == pytest.approx({ "a set of carpenter's tools": 0.48493319749832153, 'a hand tool with a heavy rigid head and a handle; used to deliver an impulsive force by striking': 0.1228410005569458, 'a vehicle with three wheels that is moved by foot pedals': 0.07170269638299942, 'a hand tool for driving screws; has a tip that fits into the head of a screw': 0.029927952215075493, 'portable power saw; teeth linked to form an endless chain': 0.02070867270231247 }, abs=1e-3) assert results['scores-top5-descriptions'] == pytest.approx({ "carpenter's kit, tool kit": 0.48493319749832153, 'hammer': 0.1228410005569458, 'tricycle, trike, velocipede': 0.07170269638299942, 'screwdriver': 0.029927952215075493, 'chain saw, chainsaw': 0.02070867270231247 }, abs=1e-3) Loading
imgutils/generic/classify.py +12 −8 Original line number Diff line number Diff line Loading @@ -37,6 +37,7 @@ __all__ = [ 'ClassifyModel', 'classify_predict_score', 'classify_predict', 'classify_predict_fmt', ] Loading Loading @@ -390,14 +391,8 @@ class ClassifyModel: for vname in vnames(fmt, str_only=True): matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname) if matching: if matching.group('topk'): topk = int(matching.group('topk')) else: topk = None if matching.group('label_group'): group_label = matching.group('label_group') else: group_label = 'default' topk = int(matching.group('topk')) if matching.group('topk') else None group_label = matching.group('label_group') if matching.group('label_group') else 'default' vname_to_spair[vname] = (topk, group_label) if (topk, group_label) not in d_scores: d_scores[(topk, group_label)] = _labels_scores_to_topk( Loading Loading @@ -621,3 +616,12 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str, label_gr model_name=model_name, label_group=label_group, ) def classify_predict_fmt(image: ImageTyping, repo_id: str, model_name: str, fmt='scores-top5', hf_token: Optional[str] = None): return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_fmt( image=image, model_name=model_name, fmt=fmt, )
test/generic/test_classify.py +68 −1 Original line number Diff line number Diff line Loading @@ -2,7 +2,7 @@ import pytest from PIL import Image from imgutils.generic import classify_predict_score from imgutils.generic.classify import _open_models_for_repo_id from imgutils.generic.classify import _open_models_for_repo_id, classify_predict_fmt from test.testings import get_testfile Loading Loading @@ -139,3 +139,70 @@ class TestGenericClassify: 'screwdriver': 0.029927952215075493, 'chain saw, chainsaw': 0.02070867270231247, }, abs=1e-3) def test_classify_predict_fmt(self): image = Image.open(get_testfile('png_640.png')) results = classify_predict_fmt( image, repo_id='deepghs/timms_mobilenet', model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k', ) assert results == pytest.approx({ 'n02966687': 0.48493319749832153, 'n03481172': 0.1228410005569458, 'n04482393': 0.07170269638299942, 'n04154565': 0.029927952215075493, 'n03000684': 0.02070867270231247 }, abs=1e-3) def test_classify_predict_fmt_complex(self): image = Image.open(get_testfile('png_640.png')) results = classify_predict_fmt( image, repo_id='deepghs/timms_mobilenet', model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k', fmt={ 'scores-top10': 'scores-top10', 'scores-top10-descriptions': 'scores-top10-descriptions', 'scores-top5-definitions': 'scores-top5-definitions', 'scores-top5-descriptions': 'scores-top5-descriptions', } ) assert results['scores-top10'] == pytest.approx({ 'n02966687': 0.48493319749832153, 'n03481172': 0.1228410005569458, 'n04482393': 0.07170269638299942, 'n04154565': 0.029927952215075493, 'n03000684': 0.02070867270231247, 'n03498962': 0.019339734688401222, 'n03444034': 0.013918918557465076, 'n03995372': 0.009074677713215351, 'n03794056': 0.00785701535642147, 'n03384352': 0.007194260135293007 }, abs=1e-3) assert results['scores-top10-descriptions'] == pytest.approx({ "carpenter's kit, tool kit": 0.48493319749832153, 'hammer': 0.1228410005569458, 'tricycle, trike, velocipede': 0.07170269638299942, 'screwdriver': 0.029927952215075493, 'chain saw, chainsaw': 0.02070867270231247, 'hatchet': 0.019339734688401222, 'go-kart': 0.013918918557465076, 'power drill': 0.009074677713215351, 'mousetrap': 0.00785701535642147, 'forklift': 0.007194260135293007 }, abs=1e-3) assert results['scores-top5-definitions'] == pytest.approx({ "a set of carpenter's tools": 0.48493319749832153, 'a hand tool with a heavy rigid head and a handle; used to deliver an impulsive force by striking': 0.1228410005569458, 'a vehicle with three wheels that is moved by foot pedals': 0.07170269638299942, 'a hand tool for driving screws; has a tip that fits into the head of a screw': 0.029927952215075493, 'portable power saw; teeth linked to form an endless chain': 0.02070867270231247 }, abs=1e-3) assert results['scores-top5-descriptions'] == pytest.approx({ "carpenter's kit, tool kit": 0.48493319749832153, 'hammer': 0.1228410005569458, 'tricycle, trike, velocipede': 0.07170269638299942, 'screwdriver': 0.029927952215075493, 'chain saw, chainsaw': 0.02070867270231247 }, abs=1e-3)