Commit 4214928c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): migrate for classify

parent cb2407ba
Loading
Loading
Loading
Loading
+32 −12
Original line number Diff line number Diff line
@@ -17,12 +17,13 @@ from threading import Lock
from typing import Tuple, Optional, List, Dict, Callable

import numpy as np
import requests
from PIL import Image
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub.errors import EntryNotFoundError
from huggingface_hub.errors import EntryNotFoundError, OfflineModeIsEnabled

from ..data import rgb_encode, ImageTyping, load_image
from ..preprocess import create_pillow_transforms
@@ -108,6 +109,7 @@ def _labels_scores_to_topk(labels: np.ndarray, scores: np.ndarray, topk: Optiona


ImagePreprocessFunc = Callable[[Image.Image], Image.Image]
_OFFLINE = object()


class ClassifyModel:
@@ -181,6 +183,7 @@ class ClassifyModel:
        """
        with self._global_lock:
            if self._model_names is None:
                try:
                    hf_fs = HfFileSystem(token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id)))
@@ -190,6 +193,12 @@ class ClassifyModel:
                            filename='*/model.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -202,7 +211,10 @@ class ClassifyModel:

        :raises ValueError: If model name is not found in repository
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -462,6 +474,10 @@ class ClassifyModel:
        # demo for classifier model
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -481,7 +497,11 @@ class ClassifyModel:
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)
                    gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()),
                                                 value='default', label='Label Group')
                with gr.Row():
+32 −1
Original line number Diff line number Diff line
from unittest.mock import patch

import numpy as np
import pytest
from PIL import Image
from huggingface_hub.utils import reset_sessions

from imgutils.generic import classify_predict_score
from imgutils.generic.classify import _open_models_for_repo_id, classify_predict_fmt
@@ -12,7 +15,18 @@ def _release_model_after_run():
    try:
        yield
    finally:
        _open_models_for_repo_id('deepghs/timms_mobilenet').clear()
        _open_models_for_repo_id.cache_clear()


@pytest.fixture()
def clean_session():
    reset_sessions()
    _open_models_for_repo_id.cache_clear()
    try:
        yield
    finally:
        reset_sessions()
        _open_models_for_repo_id.cache_clear()


@pytest.mark.unittest
@@ -216,3 +230,20 @@ class TestGenericClassify:
        emb_sims = (emb_1 * emb_2).sum()
        assert emb_sims >= 0.99, 'Direction not match with expected embedding.'
        assert np.linalg.norm(results['embedding']) == pytest.approx(np.linalg.norm(expected_embedding))

    @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True)
    def test_classify_predict_score_top5_offline_mode(self, clean_session):
        image = Image.open(get_testfile('png_640.png'))
        scores = classify_predict_score(
            image,
            repo_id='deepghs/timms_mobilenet',
            model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k',
            topk=5,
        )
        assert scores == pytest.approx({
            'n02966687': 0.48493319749832153,
            'n03481172': 0.1228410005569458,
            'n04482393': 0.07170269638299942,
            'n04154565': 0.029927952215075493,
            'n03000684': 0.02070867270231247,
        }, abs=1e-3)