Unverified Commit d8c335d0 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #151 from deepghs/dev/offline

dev(narugo): guess what is this
parents 96d10def b257c7d0
Loading
Loading
Loading
Loading
+35 −14
Original line number Diff line number Diff line
@@ -17,12 +17,13 @@ from threading import Lock
from typing import Tuple, Optional, List, Dict, Callable

import numpy as np
import requests
from PIL import Image
from hfutils.operate import get_hf_client
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub.errors import EntryNotFoundError
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError, OfflineModeIsEnabled

from ..data import rgb_encode, ImageTyping, load_image
from ..preprocess import create_pillow_transforms
@@ -108,6 +109,7 @@ def _labels_scores_to_topk(labels: np.ndarray, scores: np.ndarray, topk: Optiona


ImagePreprocessFunc = Callable[[Image.Image], Image.Image]
_OFFLINE = object()


class ClassifyModel:
@@ -181,7 +183,8 @@ class ClassifyModel:
        """
        with self._global_lock:
            if self._model_names is None:
                hf_fs = HfFileSystem(token=self._get_hf_token())
                try:
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id)))
                        for item in hf_fs.glob(hf_fs_path(
@@ -190,6 +193,12 @@ class ClassifyModel:
                            filename='*/model.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -202,7 +211,10 @@ class ClassifyModel:

        :raises ValueError: If model name is not found in repository
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -437,6 +449,7 @@ class ClassifyModel:

        This method frees up memory by removing all loaded models and labels from the cache.
        """
        self._model_names = None
        self._models.clear()
        self._labels.clear()
        self._preprocesses.clear()
@@ -462,6 +475,10 @@ class ClassifyModel:
        # demo for classifier model
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -481,7 +498,11 @@ class ClassifyModel:
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)
                    gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()),
                                                 value='default', label='Label Group')
                with gr.Row():
+37 −13
Original line number Diff line number Diff line
@@ -27,10 +27,12 @@ from threading import Lock
from typing import List, Union, Optional, Any, Dict

import numpy as np
from hfutils.operate import get_hf_client
import requests
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_normpath, parse_hf_fs_path, hf_fs_path
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import OfflineModeIsEnabled
from tokenizers import Tokenizer

from imgutils.data import MultiImagesTyping, load_images, ImageTyping
@@ -64,6 +66,9 @@ def _check_gradio_env():
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


_OFFLINE = object()


class CLIPModel:
    """
    Main interface for CLIP model operations.
@@ -124,7 +129,8 @@ class CLIPModel:
        """
        with self._global_lock:
            if self._model_names is None:
                hf_fs = HfFileSystem(token=self._get_hf_token())
                try:
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(parse_hf_fs_path(fspath).filename))
                        for fspath in hf_fs.glob(hf_fs_path(
@@ -133,6 +139,12 @@ class CLIPModel:
                            filename='**/image_encode.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -144,7 +156,10 @@ class CLIPModel:
        :type model_name: str
        :raises ValueError: If model name is not found in repository
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -402,6 +417,7 @@ class CLIPModel:

        Use this to free memory when switching between different model variants.
        """
        self._model_names = None
        self._image_encoders.clear()
        self._image_preprocessors.clear()
        self._text_encoders.clear()
@@ -423,6 +439,10 @@ class CLIPModel:
        """
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -454,7 +474,11 @@ class CLIPModel:
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels',
                                              placeholder='Enter labels, one per line')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)

                gr_submit = gr.Button(value='Submit', variant='primary')

+36 −13
Original line number Diff line number Diff line
@@ -20,10 +20,12 @@ from threading import Lock
from typing import List, Union, Optional, Any, Dict

import numpy as np
from hfutils.operate import get_hf_client
import requests
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_normpath, hf_fs_path, parse_hf_fs_path
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import OfflineModeIsEnabled
from tokenizers import Tokenizer

from ..data import MultiImagesTyping, load_images, ImageTyping
@@ -57,6 +59,9 @@ def _check_gradio_env():
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


_OFFLINE = object()


class SigLIPModel:
    """
    Main class for managing and using SigLIP models.
@@ -109,7 +114,8 @@ class SigLIPModel:
        """
        with self._global_lock:
            if self._model_names is None:
                hf_fs = HfFileSystem(token=self._get_hf_token())
                try:
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(parse_hf_fs_path(fspath).filename))
                        for fspath in hf_fs.glob(hf_fs_path(
@@ -118,6 +124,12 @@ class SigLIPModel:
                            filename='**/image_encode.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -129,7 +141,10 @@ class SigLIPModel:
        :type model_name: str
        :raises ValueError: If model name is not found in repository
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -410,6 +425,10 @@ class SigLIPModel:
        """
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -441,7 +460,11 @@ class SigLIPModel:
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels',
                                              placeholder='Enter labels, one per line')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)

                gr_submit = gr.Button(value='Submit', variant='primary')

+76 −28
Original line number Diff line number Diff line
@@ -19,12 +19,15 @@ from threading import Lock
from typing import List, Optional, Tuple, Union

import numpy as np
import requests
from PIL import Image
from hbutils.color import rnd_colors
from hbutils.design import SingletonMark
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import OfflineModeIsEnabled, EntryNotFoundError

from ..data import load_image, rgb_encode, ImageTyping
from ..utils import open_onnx_model, ts_lru_cache
@@ -454,6 +457,9 @@ def _safe_eval_names_str(names_str):
    return result


_OFFLINE = SingletonMark('OFFLINE')


class YOLOModel:
    """
    A class to manage YOLO models from a Hugging Face repository.
@@ -503,12 +509,18 @@ class YOLOModel:
        """
        Get the list of available model names in the repository.

        :return: List of model names.
        This property performs a glob search in the Hugging Face repository to find all ONNX models.
        The search is thread-safe and implements caching to avoid repeated filesystem operations.
        Results are normalized to provide consistent path formats.

        :return: List of available model names in the repository. Returns _OFFLINE list if offline mode is enabled
                or connection errors occur.
        :rtype: List[str]
        """
        with self._global_lock:
            if self._model_names is None:
                hf_fs = HfFileSystem(token=self._get_hf_token())
                try:
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id)))
                        for item in hf_fs.glob(hf_fs_path(
@@ -517,6 +529,12 @@ class YOLOModel:
                            filename='*/model.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -524,11 +542,19 @@ class YOLOModel:
        """
        Check if the given model name is valid for this repository.

        :param model_name: Name of the model to check.
        This method validates model names against the available models in the repository.
        Validation is skipped in offline mode to allow for local operations.

        :param model_name: Name of the model to check against the repository's available models.
        :type model_name: str
        :raises ValueError: If the model name is not found in the repository.
        :raises ValueError: If the model name is not found in the repository and not in offline mode.
                           The error message includes available model names for reference.
        :note: This method is a helper function primarily used internally for model validation.
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -545,8 +571,9 @@ class YOLOModel:
            if model_name not in self._models:
                self._check_model_name(model_name)
                model = open_onnx_model(hf_hub_download(
                    self.repo_id,
                    f'{model_name}/model.onnx',
                    repo_id=self.repo_id,
                    repo_type='model',
                    filename=f'{model_name}/model.onnx',
                    token=self._get_hf_token(),
                ))
                model_metadata = model.get_modelmeta()
@@ -564,17 +591,20 @@ class YOLOModel:
    def _get_model_type(self, model_name: str):
        with self._model_lock:
            if model_name not in self._model_types:
                hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                fs_path = hf_fs_path(
                try:
                    model_type_file = hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename=f'{model_name}/model_type.json',
                        revision='main',
                        token=self._get_hf_token()
                    )
                if hf_fs.exists(fs_path):
                    model_type = json.loads(hf_fs.read_text(fs_path))['model_type']
                else:
                except (EntryNotFoundError,):
                    model_type = 'yolo'
                else:
                    with open(model_type_file, 'r') as f:
                        model_type = json.load(f)['model_type']

                self._model_types[model_name] = model_type

        return self._model_types[model_name]
@@ -640,10 +670,18 @@ class YOLOModel:
        """
        Clear cached model and metadata.

        This method performs a complete cleanup by:

        1. Removing stored model names
        2. Clearing the model cache
        3. Clearing model type information

        This method removes all cached models and their associated metadata from memory.
        It's useful for freeing up memory or ensuring that the latest versions of models are loaded.
        """
        self._model_names = None
        self._models.clear()
        self._model_types.clear()

    def make_ui(self, default_model_name: Optional[str] = None,
                default_conf_threshold: float = 0.25, default_iou_threshold: float = 0.7):
@@ -663,6 +701,7 @@ class YOLOModel:
        :type default_iou_threshold: float

        :raises ImportError: If Gradio is not installed in the environment.
        :raises EnvironmentError: If in OFFLINE mode and no default_model_name is provided.

        :Example:

@@ -671,6 +710,10 @@ class YOLOModel:
        """
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -711,7 +754,11 @@ class YOLOModel:
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)
                    gr_allow_dynamic = gr.Checkbox(value=False, label='Allow Dynamic Size')
                with gr.Row():
                    gr_iou_threshold = gr.Slider(0.0, 1.0, default_iou_threshold, label='IOU Threshold')
@@ -756,7 +803,8 @@ class YOLOModel:
        :type server_port: Optional[int]
        :param kwargs: Additional keyword arguments to pass to gr.Blocks.launch().

        :raises EnvironmentError: If Gradio is not installed in the environment.
        :raises EnvironmentError: If Gradio is not installed in the environment,
                                  or if in OFFLINE mode and no default_model_name is provided.

        Example:
            >>> model = YOLOModel("username/repo_name")
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ pilmoji>=1.3.0
shapely!=2.0.7
pyclipper
deprecation>=2.0.0
hfutils>=0.4.2
hfutils>=0.8.0
filelock
bchlib>=1.0.0,!=2.0.0,!=2.0.1,!=2.1.0,!=2.1.1,!=2.1.2
piexif
Loading