Commit dfe5889d authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): test for clip

parent 4214928c
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -19,10 +19,10 @@ from typing import Tuple, Optional, List, Dict, Callable
import numpy as np
import requests
from PIL import Image
from hfutils.operate import get_hf_client
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError, OfflineModeIsEnabled

from ..data import rgb_encode, ImageTyping, load_image
@@ -184,7 +184,7 @@ class ClassifyModel:
        with self._global_lock:
            if self._model_names is None:
                try:
                    hf_fs = HfFileSystem(token=self._get_hf_token())
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id)))
                        for item in hf_fs.glob(hf_fs_path(
@@ -449,6 +449,7 @@ class ClassifyModel:

        This method frees up memory by removing all loaded models and labels from the cache.
        """
        self._model_names = None
        self._models.clear()
        self._labels.clear()
        self._preprocesses.clear()
+37 −13
Original line number Diff line number Diff line
@@ -27,10 +27,12 @@ from threading import Lock
from typing import List, Union, Optional, Any, Dict

import numpy as np
from hfutils.operate import get_hf_client
import requests
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_normpath, parse_hf_fs_path, hf_fs_path
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import OfflineModeIsEnabled
from tokenizers import Tokenizer

from imgutils.data import MultiImagesTyping, load_images, ImageTyping
@@ -64,6 +66,9 @@ def _check_gradio_env():
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


_OFFLINE = object()


class CLIPModel:
    """
    Main interface for CLIP model operations.
@@ -124,7 +129,8 @@ class CLIPModel:
        """
        with self._global_lock:
            if self._model_names is None:
                hf_fs = HfFileSystem(token=self._get_hf_token())
                try:
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(parse_hf_fs_path(fspath).filename))
                        for fspath in hf_fs.glob(hf_fs_path(
@@ -133,6 +139,12 @@ class CLIPModel:
                            filename='**/image_encode.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -144,7 +156,10 @@ class CLIPModel:
        :type model_name: str
        :raises ValueError: If model name is not found in repository
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -402,6 +417,7 @@ class CLIPModel:

        Use this to free memory when switching between different model variants.
        """
        self._model_names = None
        self._image_encoders.clear()
        self._image_preprocessors.clear()
        self._text_encoders.clear()
@@ -423,6 +439,10 @@ class CLIPModel:
        """
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -454,7 +474,11 @@ class CLIPModel:
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels',
                                              placeholder='Enter labels, one per line')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)

                gr_submit = gr.Button(value='Submit', variant='primary')

+3 −3
Original line number Diff line number Diff line
@@ -23,10 +23,10 @@ import requests
from PIL import Image
from hbutils.color import rnd_colors
from hbutils.design import SingletonMark
from hfutils.operate import get_hf_client
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import HfFileSystem, hf_hub_download
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import OfflineModeIsEnabled, EntryNotFoundError

from ..data import load_image, rgb_encode, ImageTyping
@@ -520,7 +520,7 @@ class YOLOModel:
        with self._global_lock:
            if self._model_names is None:
                try:
                    hf_fs = HfFileSystem(token=self._get_hf_token())
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(os.path.relpath(item, self.repo_id)))
                        for item in hf_fs.glob(hf_fs_path(
+1 −1
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ pilmoji>=1.3.0
shapely!=2.0.7
pyclipper
deprecation>=2.0.0
hfutils>=0.4.2
hfutils>=0.8.0
filelock
bchlib>=1.0.0,!=2.0.0,!=2.0.1,!=2.1.0,!=2.1.1,!=2.1.2
piexif
+33 −1
Original line number Diff line number Diff line
import re
from unittest.mock import patch

import numpy as np
import pytest
from huggingface_hub.utils import reset_sessions

from imgutils.generic.clip import _open_models_for_repo_id, clip_image_encode, clip_text_encode, clip_predict
from test.testings import get_testfile
@@ -22,7 +24,18 @@ def _release_model_after_run(clip_repo_id):
    try:
        yield
    finally:
        _open_models_for_repo_id(clip_repo_id).clear()
        _open_models_for_repo_id.cache_clear()


@pytest.fixture()
def clean_session():
    reset_sessions()
    _open_models_for_repo_id.cache_clear()
    try:
        yield
    finally:
        reset_sessions()
        _open_models_for_repo_id.cache_clear()


@pytest.mark.unittest
@@ -81,3 +94,22 @@ class TestGenericCLIP:
        expected_result = np.array([[0.9803991317749023, 0.005067288409918547, 0.01453354675322771],
                                    [0.21404513716697693, 0.049479320645332336, 0.7364755272865295]])
        np.testing.assert_allclose(result, expected_result, atol=3e-4)

    @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True)
    def test_clip_predict_with_offline_mode(self, clip_repo_id, clip_model_name, clean_session):
        result = clip_predict(
            images=[
                get_testfile('clip_cats.jpg'),
                get_testfile('idolsankaku', '3.jpg'),
            ],
            texts=[
                'a photo of a cat',
                'a photo of a dog',
                'a photo of a human',
            ],
            repo_id=clip_repo_id,
            model_name=clip_model_name,
        )
        expected_result = np.array([[0.9803991317749023, 0.005067288409918547, 0.01453354675322771],
                                    [0.21404513716697693, 0.049479320645332336, 0.7364755272865295]])
        np.testing.assert_allclose(result, expected_result, atol=3e-4)