Commit b257c7d0 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest for siglip

parent dfe5889d
Loading
Loading
Loading
Loading
+36 −13
Original line number Diff line number Diff line
@@ -20,10 +20,12 @@ from threading import Lock
from typing import List, Union, Optional, Any, Dict

import numpy as np
from hfutils.operate import get_hf_client
import requests
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_normpath, hf_fs_path, parse_hf_fs_path
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import OfflineModeIsEnabled
from tokenizers import Tokenizer

from ..data import MultiImagesTyping, load_images, ImageTyping
@@ -57,6 +59,9 @@ def _check_gradio_env():
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


_OFFLINE = object()


class SigLIPModel:
    """
    Main class for managing and using SigLIP models.
@@ -109,7 +114,8 @@ class SigLIPModel:
        """
        with self._global_lock:
            if self._model_names is None:
                hf_fs = HfFileSystem(token=self._get_hf_token())
                try:
                    hf_fs = get_hf_fs(hf_token=self._get_hf_token())
                    self._model_names = [
                        hf_normpath(os.path.dirname(parse_hf_fs_path(fspath).filename))
                        for fspath in hf_fs.glob(hf_fs_path(
@@ -118,6 +124,12 @@ class SigLIPModel:
                            filename='**/image_encode.onnx',
                        ))
                    ]
                except (
                        requests.exceptions.ConnectionError,
                        requests.exceptions.Timeout,
                        OfflineModeIsEnabled,
                ):
                    self._model_names = _OFFLINE

        return self._model_names

@@ -129,7 +141,10 @@ class SigLIPModel:
        :type model_name: str
        :raises ValueError: If model name is not found in repository
        """
        if model_name not in self.model_names:
        model_list = self.model_names
        if model_list is _OFFLINE:
            return  # do not check when in offline mode
        if model_name not in model_list:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

@@ -410,6 +425,10 @@ class SigLIPModel:
        """
        _check_gradio_env()
        model_list = self.model_names
        if model_list is _OFFLINE and not default_model_name:
            raise EnvironmentError('You are in OFFLINE mode, '
                                   'you must assign a default model name to make this ui usable.')

        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
@@ -441,7 +460,11 @@ class SigLIPModel:
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels',
                                              placeholder='Enter labels, one per line')
                with gr.Row():
                    if model_list is not _OFFLINE:
                        gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    else:
                        gr_model = gr.Dropdown([default_model_name], value=default_model_name, label='Model',
                                               interactive=False)

                gr_submit = gr.Button(value='Submit', variant='primary')

+35 −1
Original line number Diff line number Diff line
import re
from unittest.mock import patch

import numpy as np
import pytest
from huggingface_hub.utils import reset_sessions

from imgutils.generic import siglip_image_encode, siglip_text_encode, siglip_predict
from imgutils.generic.siglip import _open_models_for_repo_id
@@ -23,7 +25,18 @@ def _release_model_after_run(siglip_repo_id):
    try:
        yield
    finally:
        _open_models_for_repo_id(siglip_repo_id).clear()
        _open_models_for_repo_id.cache_clear()


@pytest.fixture()
def clean_session():
    reset_sessions()
    _open_models_for_repo_id.cache_clear()
    try:
        yield
    finally:
        reset_sessions()
        _open_models_for_repo_id.cache_clear()


@pytest.mark.unittest
@@ -84,3 +97,24 @@ class TestGenericSiglip:
            [[0.0013782851165160537, 0.27010253071784973, 9.751768811838701e-05, 3.6702780814579228e-09],
             [1.2790776438009743e-08, 4.396981001519862e-09, 3.2838454178119036e-10, 1.0559210750216153e-06]])
        np.testing.assert_allclose(result, expected_result, atol=3e-4)

    @patch("huggingface_hub.constants.HF_HUB_OFFLINE", True)
    def test_siglip_predict_with_offline_mode(self, siglip_repo_id, siglip_model_name, clean_session):
        result = siglip_predict(
            images=[
                get_testfile('clip_cats.jpg'),
                get_testfile('idolsankaku', '3.jpg'),
            ],
            texts=[
                'a photo of a cat',
                'a photo of 2 cats',
                'a photo of 2 dogs',
                'a photo of a woman',
            ],
            repo_id=siglip_repo_id,
            model_name=siglip_model_name,
        )
        expected_result = np.array(
            [[0.0013782851165160537, 0.27010253071784973, 9.751768811838701e-05, 3.6702780814579228e-09],
             [1.2790776438009743e-08, 4.396981001519862e-09, 3.2838454178119036e-10, 1.0559210750216153e-06]])
        np.testing.assert_allclose(result, expected_result, atol=3e-4)