Unverified Commit 4ab6c58e authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #132 from deepghs/dev/cls

dev(narugo): add classifier multi-label system for generic timm exported models
parents a74b573c 33902e1e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -124,7 +124,7 @@ jobs:
      - name: Run unittest
        env:
          CI: 'true'
          REMOTE_PIXIV_SESSION_INDEX_URL: ${{ secrets.REMOTE_PIXIV_SESSION_INDEX_URL }}
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        shell: bash
        run: |
          make unittest IS_WIN=${{ env.IS_WIN }} IS_MAC=${{ env.IS_MAC }}
+104 −16
Original line number Diff line number Diff line
@@ -21,8 +21,10 @@ from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub.errors import EntryNotFoundError

from ..data import rgb_encode, ImageTyping, load_image
from ..preprocess import create_pillow_transforms
from ..utils import open_onnx_model, ts_lru_cache

try:
@@ -133,6 +135,7 @@ class ClassifyModel:
        self._model_names = None
        self._models = {}
        self._labels = {}
        self._preprocesses = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_lock = Lock()
@@ -214,7 +217,7 @@ class ClassifyModel:

        return self._models[model_name]

    def _open_label(self, model_name: str) -> List[str]:
    def _open_label(self, model_name: str) -> Dict[str, np.ndarray]:
        """
        Load and cache model labels from metadata.

@@ -225,7 +228,7 @@ class ClassifyModel:
        :type model_name: str

        :return: Mapping from label group name to the array of labels in that group
        :rtype: List[str]
        :rtype: Dict[str, np.ndarray]

        :raises RuntimeError: If label loading fails
        """
@@ -237,10 +240,36 @@ class ClassifyModel:
                        f'{model_name}/meta.json',
                        token=self._get_hf_token(),
                ), 'r') as f:
                    self._labels[model_name] = json.load(f)['labels']
                    meta_info = json.load(f)
                    d_groups = {
                        **(meta_info.get('other_labels') or {}),
                        'default': meta_info['labels']
                    }
                    self._labels[model_name] = {
                        key: np.array(labels)
                        for key, labels in d_groups.items()
                    }

        return self._labels[model_name]

    def _open_preprocess(self, model_name: str):
        with self._model_lock:
            if model_name not in self._preprocesses:
                try:
                    pfile = hf_hub_download(
                        self.repo_id,
                        f'{model_name}/preprocess.json',
                        token=self._get_hf_token(),
                    )
                except EntryNotFoundError:
                    self._preprocesses[model_name] = None
                else:
                    with open(pfile, 'r') as f:
                        stages_info = json.load(f)['stages']
                        self._preprocesses[model_name] = create_pillow_transforms(stages_info)

            return self._preprocesses[model_name]

    def _raw_predict(self, image: ImageTyping, model_name: str):
        """
        Generate raw model predictions for an input image.
@@ -271,6 +300,10 @@ class ClassifyModel:
        if self._fn_preprocess:
            image = self._fn_preprocess(image)

        preprocess = self._open_preprocess(model_name=model_name)
        if preprocess:
            input_ = preprocess(image)[None, ...]
        else:
            if isinstance(height, int) and isinstance(width, int):
                input_ = _img_encode(image, size=(width, height))[None, ...]
            else:
@@ -278,7 +311,8 @@ class ClassifyModel:
        output, = self._open_model(model_name).run(['output'], {'input': input_})
        return output

    def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]:
    def predict_score(self, image: ImageTyping, model_name: str,
                      label_group: str = 'default', topk: Optional[int] = 20) -> Dict[str, float]:
        """
        Predict the scores for each class using the specified model.

@@ -288,6 +322,10 @@ class ClassifyModel:
        :type image: ImageTyping
        :param model_name: The name of the model to use for prediction.
        :type model_name: str
        :param label_group: Label group for the classification result.
        :type label_group: str
        :param topk: Top-K result. Default is 20, return all results when None is assigned.
        :type topk: Optional[int]

        :return: A dictionary mapping class labels to their predicted scores.
        :rtype: Dict[str, float]
@@ -296,10 +334,20 @@ class ClassifyModel:
        :raises RuntimeError: If there's an error during prediction.
        """
        output = self._raw_predict(image, model_name)
        values = dict(zip(self._open_label(model_name), map(lambda x: x.item(), output[0])))
        labels = self._open_label(model_name)[label_group]
        scores = output[0]
        if topk and topk < labels.shape[-1]:
            indices = np.argpartition(scores, -topk)[-topk:]
            indices = indices[np.argsort(-scores[indices], kind='mergesort')]
        else:
            indices = np.argsort(-scores, kind='mergesort')
        labels, scores = labels[indices], scores[indices]

        # noinspection PyTypeChecker
        values = dict(zip(labels.tolist(), scores.tolist()))
        return values

    def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]:
    def predict(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Tuple[str, float]:
        """
        Predict the class with the highest score for the given image.

@@ -309,6 +357,8 @@ class ClassifyModel:
        :type image: ImageTyping
        :param model_name: The name of the model to use for prediction.
        :type model_name: str
        :param label_group: Label group for the classification result.
        :type label_group: str

        :return: A tuple containing the predicted class label and its score.
        :rtype: Tuple[str, float]
@@ -318,7 +368,7 @@ class ClassifyModel:
        """
        output = self._raw_predict(image, model_name)[0]
        max_id = np.argmax(output)
        return self._open_label(model_name)[max_id], output[max_id].item()
        return self._open_label(model_name)[label_group][max_id], output[max_id].item()

    def clear(self):
        """
@@ -366,21 +416,43 @@ class ClassifyModel:

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()),
                                                 value='default', label='Label Group')
                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output = gr.Label(label='Prediction')

            def _fn_label_group(new_model_name, old_label_group):
                labels_info = self._open_label(new_model_name)
                return gr.Dropdown(
                    list(labels_info.keys()),
                    value=old_label_group if old_label_group in labels_info else 'default',
                    label='Label Group'
                )

            gr_submit.click(
                self.predict_score,
                inputs=[
                    gr_input_image,
                    gr_model,
                    gr_label_group,
                ],
                outputs=[gr_output],
            )
            gr_model.change(
                _fn_label_group,
                inputs=[
                    gr_model,
                    gr_label_group
                ],
                outputs=[gr_label_group]
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
@@ -444,6 +516,7 @@ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> Cl


def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
                           label_group: str = 'default', topk: Optional[int] = 20,
                           hf_token: Optional[str] = None) -> Dict[str, float]:
    """
    Predict the scores for each class using the specified model and repository.
@@ -456,6 +529,10 @@ def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param label_group: Label group for the classification result.
    :type label_group: str
    :param topk: Top-K result. Default is 20, return all results when None is assigned.
    :type topk: Optional[int]
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

@@ -465,10 +542,15 @@ def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
    :raises ValueError: If the model name or repository ID is invalid.
    :raises RuntimeError: If there's an error during prediction.
    """
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_score(image, model_name)
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_score(
        image=image,
        model_name=model_name,
        label_group=label_group,
        topk=topk,
    )


def classify_predict(image: ImageTyping, repo_id: str, model_name: str,
def classify_predict(image: ImageTyping, repo_id: str, model_name: str, label_group: str = 'default',
                     hf_token: Optional[str] = None) -> Tuple[str, float]:
    """
    Predict the class with the highest score using the specified model and repository.
@@ -481,6 +563,8 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str,
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param label_group: Label group for the classification result.
    :type label_group: str
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

@@ -490,4 +574,8 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str,
    :raises ValueError: If the model name or repository ID is invalid.
    :raises RuntimeError: If there's an error during prediction.
    """
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict(image, model_name)
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict(
        image=image,
        model_name=model_name,
        label_group=label_group,
    )
+0 −0

Empty file added.

+132 −0
Original line number Diff line number Diff line
import pytest
from PIL import Image

from imgutils.generic import classify_predict_score
from test.testings import get_testfile


@pytest.mark.unittest
class TestGenericClassify:
    """
    Tests for :func:`classify_predict_score` on one fixed image and model.

    Every test queries the same image, repository and model. The expected
    scores are the same for every label group; only the label text differs,
    so the scores are stored once and paired with each group's labels.
    """

    _REPO_ID = 'deepghs/timms_mobilenet'
    _MODEL_NAME = 'mobilenetv4_hybrid_medium.ix_e550_r384_in1k'
    _IMAGE_FILE = 'png_640.png'

    # Expected top-20 scores, highest first.
    _TOP_SCORES = (
        0.48493319749832153,
        0.1228410005569458,
        0.07170269638299942,
        0.029927952215075493,
        0.02070867270231247,
        0.019339734688401222,
        0.013918918557465076,
        0.009074677713215351,
        0.00785701535642147,
        0.007194260135293007,
        0.005858728662133217,
        0.005415018182247877,
        0.005253748968243599,
        0.004338286351412535,
        0.004189436789602041,
        0.004000757820904255,
        0.0034237923100590706,
        0.0027278559282422066,
        0.0026790976990014315,
        0.0026228304486721754,
    )

    # Labels of the 'default' group, in the same order as _TOP_SCORES.
    _DEFAULT_LABELS = (
        'n02966687',
        'n03481172',
        'n04482393',
        'n04154565',
        'n03000684',
        'n03498962',
        'n03444034',
        'n03995372',
        'n03794056',
        'n03384352',
        'n02930766',
        'n02835271',
        'n04336792',
        'n04509417',
        'n03792782',
        'n03532672',
        'n03109150',
        'n04517823',
        'n03126707',
        'n02879718',
    )

    # Labels of the 'descriptions' group, in the same order as _TOP_SCORES.
    _DESCRIPTION_LABELS = (
        "carpenter's kit, tool kit",
        'hammer',
        'tricycle, trike, velocipede',
        'screwdriver',
        'chain saw, chainsaw',
        'hatchet',
        'go-kart',
        'power drill',
        'mousetrap',
        'forklift',
        'cab, hack, taxi, taxicab',
        'bicycle-built-for-two, tandem bicycle, tandem',
        'stretcher',
        'unicycle, monocycle',
        'mountain bike, all-terrain bike, off-roader',
        'hook, claw',
        'corkscrew, bottle screw',
        'vacuum, vacuum cleaner',
        'crane',
        'bow',
    )

    # Labels of the 'definitions' group, in the same order as _TOP_SCORES.
    _DEFINITION_LABELS = (
        "a set of carpenter's tools",
        'a hand tool with a heavy rigid head and a handle; used to deliver an impulsive force by striking',
        'a vehicle with three wheels that is moved by foot pedals',
        'a hand tool for driving screws; has a tip that fits into the head of a screw',
        'portable power saw; teeth linked to form an endless chain',
        'a small ax with a short handle used with one hand (usually to chop wood)',
        'a small low motor vehicle with four wheels and an open framework; used for racing',
        'a power tool for drilling holes into hard materials',
        'a trap for catching mice',
        'a small industrial vehicle with a power operated forked platform in front that can be inserted under loads to lift and move them',
        'a car driven by a person whose job is to take passengers where they want to go in exchange for money',
        'a bicycle with two sets of pedals and two seats',
        'a litter for transporting people who are ill or wounded or dead; usually consists of a sheet of canvas stretched between two poles',
        'a vehicle with a single wheel that is driven by pedals',
        'a bicycle with a sturdy frame and fat tires; originally designed for riding in mountainous country',
        'a mechanical device that is curved or bent to suspend or hold or pull something',
        'a bottle opener that pulls corks',
        'an electrical home appliance that cleans by suction',
        'lifts and moves heavy objects; lifting tackle is suspended from a pivoted boom that rotates around a vertical axis',
        'a weapon for shooting arrows, composed of a curved piece of resilient wood with a taut cord to propel the arrow',
    )

    def _score(self, **options):
        """Call ``classify_predict_score`` on the shared image and model with extra keyword options."""
        image = Image.open(get_testfile(self._IMAGE_FILE))
        return classify_predict_score(
            image,
            repo_id=self._REPO_ID,
            model_name=self._MODEL_NAME,
            **options,
        )

    def _expected(self, labels, count=None):
        """Pair the first ``count`` labels (all when ``None``) with their scores, wrapped in ``pytest.approx``."""
        pairs = zip(labels[:count], self._TOP_SCORES[:count])
        return pytest.approx(dict(pairs), abs=1e-3)

    def test_classify_predict_score(self):
        """Default label group with the default ``topk``."""
        scores = self._score()
        assert scores == self._expected(self._DEFAULT_LABELS)

    def test_classify_predict_score_group1(self):
        """The ``descriptions`` label group with the default ``topk``."""
        scores = self._score(label_group='descriptions')
        assert scores == self._expected(self._DESCRIPTION_LABELS)

    def test_classify_predict_score_group2(self):
        """The ``definitions`` label group with the default ``topk``."""
        scores = self._score(label_group='definitions')
        assert scores == self._expected(self._DEFINITION_LABELS)

    def test_classify_predict_score_top5(self):
        """Default label group limited to the five best results."""
        scores = self._score(topk=5)
        assert scores == self._expected(self._DEFAULT_LABELS, count=5)

    def test_classify_predict_score_top5_group1(self):
        """The ``descriptions`` label group limited to the five best results."""
        scores = self._score(label_group='descriptions', topk=5)
        assert scores == self._expected(self._DESCRIPTION_LABELS, count=5)