Loading docs/source/api_doc/generic/classify.rst +1 −1 Original line number Diff line number Diff line Loading @@ -11,7 +11,7 @@ ClassifyModel ----------------------------------------- .. autoclass:: ClassifyModel :members: __init__, predict_score, predict, clear :members: __init__, predict_score, predict, clear, make_ui, launch_demo Loading imgutils/generic/classify.py +73 −0 Original line number Diff line number Diff line Loading @@ -23,12 +23,19 @@ from typing import Tuple, Optional, List, Dict import numpy as np from PIL import Image from hfutils.operate import get_hf_client from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import hf_hub_download, HfFileSystem from ..data import rgb_encode, ImageTyping, load_image from ..utils import open_onnx_model try: import gradio as gr except (ImportError, ModuleNotFoundError): gr = None __all__ = [ 'ClassifyModel', 'classify_predict_score', Loading @@ -36,6 +43,17 @@ __all__ = [ ] def _check_gradio_env(): """ Check if the Gradio library is installed and available. :raises EnvironmentError: If Gradio is not installed. """ if gr is None: raise EnvironmentError(f'Gradio required for launching webui-based demo.\n' f'Please install it with `pip install dghs-imgutils[demo]`.') def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): """ Loading Loading @@ -287,6 +305,61 @@ class ClassifyModel: self._models.clear() self._labels.clear() def make_ui(self, default_model_name: Optional[str] = None): _check_gradio_env() model_list = self.model_names if not default_model_name: hf_client = get_hf_client(hf_token=self._get_hf_token()) selected_model_name, selected_time = None, None for fileitem in hf_client.get_paths_info( repo_id=self.repo_id, repo_type='model', paths=[f'{model_name}/model.onnx' for model_name in model_list], expand=True, ): if not selected_time or fileitem.last_commit.date > selected_time: selected_model_name = os.path.dirname(fileitem.path) selected_time = fileitem.last_commit.date default_model_name = selected_model_name with gr.Row(): with gr.Column(): gr_input_image = gr.Image(type='pil', label='Original Image') gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') gr_submit = gr.Button(value='Submit', variant='primary') with gr.Column(): gr_output = gr.Label(label='Prediction') gr_submit.click( self.predict_score, inputs=[ gr_input_image, gr_model, ], outputs=[gr_output], ) def launch_demo(self, default_model_name: Optional[str] = None, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): _check_gradio_env() with gr.Blocks() as demo: with gr.Row(): with gr.Column(): repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model') gr.HTML(f'<h2 style="text-align: center;">Classifier Demo For {self.repo_id}</h2>') gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). ' f'Powered by `dghs-imgutils`\'s quick demo module.') with gr.Row(): self.make_ui(default_model_name=default_model_name) demo.launch( server_name=server_name, server_port=server_port, **kwargs, ) @lru_cache() def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel: Loading Loading
docs/source/api_doc/generic/classify.rst +1 −1 Original line number Diff line number Diff line Loading @@ -11,7 +11,7 @@ ClassifyModel ----------------------------------------- .. autoclass:: ClassifyModel :members: __init__, predict_score, predict, clear :members: __init__, predict_score, predict, clear, make_ui, launch_demo Loading
imgutils/generic/classify.py +73 −0 Original line number Diff line number Diff line Loading @@ -23,12 +23,19 @@ from typing import Tuple, Optional, List, Dict import numpy as np from PIL import Image from hfutils.operate import get_hf_client from hfutils.repository import hf_hub_repo_url from hfutils.utils import hf_fs_path, hf_normpath from huggingface_hub import hf_hub_download, HfFileSystem from ..data import rgb_encode, ImageTyping, load_image from ..utils import open_onnx_model try: import gradio as gr except (ImportError, ModuleNotFoundError): gr = None __all__ = [ 'ClassifyModel', 'classify_predict_score', Loading @@ -36,6 +43,17 @@ __all__ = [ ] def _check_gradio_env(): """ Check if the Gradio library is installed and available. :raises EnvironmentError: If Gradio is not installed. """ if gr is None: raise EnvironmentError(f'Gradio required for launching webui-based demo.\n' f'Please install it with `pip install dghs-imgutils[demo]`.') def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): """ Loading Loading @@ -287,6 +305,61 @@ class ClassifyModel: self._models.clear() self._labels.clear() def make_ui(self, default_model_name: Optional[str] = None): _check_gradio_env() model_list = self.model_names if not default_model_name: hf_client = get_hf_client(hf_token=self._get_hf_token()) selected_model_name, selected_time = None, None for fileitem in hf_client.get_paths_info( repo_id=self.repo_id, repo_type='model', paths=[f'{model_name}/model.onnx' for model_name in model_list], expand=True, ): if not selected_time or fileitem.last_commit.date > selected_time: selected_model_name = os.path.dirname(fileitem.path) selected_time = fileitem.last_commit.date default_model_name = selected_model_name with gr.Row(): with gr.Column(): gr_input_image = gr.Image(type='pil', label='Original Image') gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') gr_submit = gr.Button(value='Submit', variant='primary') with gr.Column(): gr_output = gr.Label(label='Prediction') gr_submit.click( self.predict_score, inputs=[ gr_input_image, gr_model, ], outputs=[gr_output], ) def launch_demo(self, default_model_name: Optional[str] = None, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): _check_gradio_env() with gr.Blocks() as demo: with gr.Row(): with gr.Column(): repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model') gr.HTML(f'<h2 style="text-align: center;">Classifier Demo For {self.repo_id}</h2>') gr.Markdown(f'This is the quick demo for classifier model [{self.repo_id}]({repo_url}). ' f'Powered by `dghs-imgutils`\'s quick demo module.') with gr.Row(): self.make_ui(default_model_name=default_model_name) demo.launch( server_name=server_name, server_port=server_port, **kwargs, ) @lru_cache() def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel: Loading