Loading imgutils/generic/classify.py +32 −12 Original line number Diff line number Diff line Loading @@ -17,12 +17,13 @@ from threading import Lock from typing import Tuple, Optional, List, Dict, Callable import numpy as np import requests from PIL import Image from hfutils.operate import get_hf_client from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import hf_hub_download, HfFileSystem from huggingface_hub.errors import EntryNotFoundError from huggingface_hub.errors import EntryNotFoundError, OfflineModeIsEnabled from ..data import rgb_encode, ImageTyping, load_image from ..preprocess import create_pillow_transforms Loading Loading @@ -108,6 +109,7 @@ def _labels_scores_to_topk(labels: np.ndarray, scores: np.ndarray, topk: Optiona ImagePreprocessFunc = Callable[[Image.Image], Image.Image] _OFFLINE = object() class ClassifyModel: Loading Loading @@ -181,6 +183,7 @@ class ClassifyModel: """ with self._global_lock: if self._model_names is None: try: hf_fs = HfFileSystem(token=self._get_hf_token()) self._model_names = [ hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id))) Loading @@ -190,6 +193,12 @@ class ClassifyModel: filename='*/model.onnx', )) ] except ( requests.exceptions.ConnectionError, requests.exceptions.Timeout, OfflineModeIsEnabled, ): self._model_names = _OFFLINE return self._model_names Loading @@ -202,7 +211,10 @@ class ClassifyModel: :raises ValueError: If model name is not found in repository """ if model_name not in self.model_names: model_list = self.model_names if model_list is _OFFLINE: return # do not check when in offline mode if model_name not in model_list: raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, ' f'models {self.model_names!r} are available.') Loading Loading @@ -462,6 +474,10 @@ class ClassifyModel: # demo for classifier model _check_gradio_env() model_list = self.model_names if model_list is _OFFLINE and not default_model_name: raise EnvironmentError('You are in OFFLINE mode, ' 'you must assign a default model name to make this ui usable.') if not default_model_name: hf_client = get_hf_client(hf_token=self._get_hf_token()) selected_model_name, selected_time = None, None Loading @@ -481,7 +497,11 @@ class ClassifyModel: with gr.Row(): gr_input_image = gr.Image(type='pil', label='Original Image') with gr.Row(): if model_list is not _OFFLINE: gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') else: gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model', interactive=False) gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()), value='default', label='Label Group') with gr.Row(): Loading test/generic/test_classify.py +32 −1 Original line number Diff line number Diff line from unittest.mock import patch import numpy as np import pytest from PIL import Image from huggingface_hub.utils import reset_sessions from imgutils.generic import classify_predict_score from imgutils.generic.classify import _open_models_for_repo_id, classify_predict_fmt Loading @@ -12,7 +15,18 @@ def _release_model_after_run(): try: yield finally: _open_models_for_repo_id('deepghs/timms_mobilenet').clear() _open_models_for_repo_id.cache_clear() @pytest.fixture() def clean_session(): reset_sessions() _open_models_for_repo_id.cache_clear() try: yield finally: reset_sessions() _open_models_for_repo_id.cache_clear() @pytest.mark.unittest Loading Loading @@ -216,3 +230,20 @@ class TestGenericClassify: emb_sims = (emb_1 * emb_2).sum() assert emb_sims >= 0.99, 'Direction not match with expected embedding.' assert np.linalg.norm(results['embedding']) == pytest.approx(np.linalg.norm(expected_embedding)) @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True) def test_classify_predict_score_top5_offline_mode(self, clean_session): image = Image.open(get_testfile('png_640.png')) scores = classify_predict_score( image, repo_id='deepghs/timms_mobilenet', model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k', topk=5, ) assert scores == pytest.approx({ 'n02966687': 0.48493319749832153, 'n03481172': 0.1228410005569458, 'n04482393': 0.07170269638299942, 'n04154565': 0.029927952215075493, 'n03000684': 0.02070867270231247, }, abs=1e-3) Loading
imgutils/generic/classify.py +32 −12 Original line number Diff line number Diff line Loading @@ -17,12 +17,13 @@ from threading import Lock from typing import Tuple, Optional, List, Dict, Callable import numpy as np import requests from PIL import Image from hfutils.operate import get_hf_client from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import hf_hub_download, HfFileSystem from huggingface_hub.errors import EntryNotFoundError from huggingface_hub.errors import EntryNotFoundError, OfflineModeIsEnabled from ..data import rgb_encode, ImageTyping, load_image from ..preprocess import create_pillow_transforms Loading Loading @@ -108,6 +109,7 @@ def _labels_scores_to_topk(labels: np.ndarray, scores: np.ndarray, topk: Optiona ImagePreprocessFunc = Callable[[Image.Image], Image.Image] _OFFLINE = object() class ClassifyModel: Loading Loading @@ -181,6 +183,7 @@ class ClassifyModel: """ with self._global_lock: if self._model_names is None: try: hf_fs = HfFileSystem(token=self._get_hf_token()) self._model_names = [ hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id))) Loading @@ -190,6 +193,12 @@ class ClassifyModel: filename='*/model.onnx', )) ] except ( requests.exceptions.ConnectionError, requests.exceptions.Timeout, OfflineModeIsEnabled, ): self._model_names = _OFFLINE return self._model_names Loading @@ -202,7 +211,10 @@ class ClassifyModel: :raises ValueError: If model name is not found in repository """ if model_name not in self.model_names: model_list = self.model_names if model_list is _OFFLINE: return # do not check when in offline mode if model_name not in model_list: raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, ' f'models {self.model_names!r} are available.') Loading Loading @@ -462,6 +474,10 @@ class ClassifyModel: # demo for classifier model _check_gradio_env() model_list = self.model_names if model_list is _OFFLINE and not default_model_name: raise EnvironmentError('You are in OFFLINE mode, ' 'you must assign a default model name to make this ui usable.') if not default_model_name: hf_client = get_hf_client(hf_token=self._get_hf_token()) selected_model_name, selected_time = None, None Loading @@ -481,7 +497,11 @@ class ClassifyModel: with gr.Row(): gr_input_image = gr.Image(type='pil', label='Original Image') with gr.Row(): if model_list is not _OFFLINE: gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') else: gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model', interactive=False) gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()), value='default', label='Label Group') with gr.Row(): Loading
test/generic/test_classify.py +32 −1 Original line number Diff line number Diff line from unittest.mock import patch import numpy as np import pytest from PIL import Image from huggingface_hub.utils import reset_sessions from imgutils.generic import classify_predict_score from imgutils.generic.classify import _open_models_for_repo_id, classify_predict_fmt Loading @@ -12,7 +15,18 @@ def _release_model_after_run(): try: yield finally: _open_models_for_repo_id('deepghs/timms_mobilenet').clear() _open_models_for_repo_id.cache_clear() @pytest.fixture() def clean_session(): reset_sessions() _open_models_for_repo_id.cache_clear() try: yield finally: reset_sessions() _open_models_for_repo_id.cache_clear() @pytest.mark.unittest Loading Loading @@ -216,3 +230,20 @@ class TestGenericClassify: emb_sims = (emb_1 * emb_2).sum() assert emb_sims >= 0.99, 'Direction not match with expected embedding.' assert np.linalg.norm(results['embedding']) == pytest.approx(np.linalg.norm(expected_embedding)) @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True) def test_classify_predict_score_top5_offline_mode(self, clean_session): image = Image.open(get_testfile('png_640.png')) scores = classify_predict_score( image, repo_id='deepghs/timms_mobilenet', model_name='mobilenetv4_hybrid_medium.ix_e550_r384_in1k', topk=5, ) assert scores == pytest.approx({ 'n02966687': 0.48493319749832153, 'n03481172': 0.1228410005569458, 'n04482393': 0.07170269638299942, 'n04154565': 0.029927952215075493, 'n03000684': 0.02070867270231247, }, abs=1e-3)