Commit d73611dc authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add siglip unittest

parent 1d221bf3
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -50,13 +50,14 @@ pdocs:
dataset:
	mkdir -p ${DATASET_DIR}
	if [ ! -d ${DATASET_DIR}/chafen_arknights ]; then \
		git clone https://huggingface.co/datasets/deepghs/chafen_arknights.git ${DATASET_DIR}/chafen_arknights; \
		hfutils download -r deepghs/chafen_arknights -t dataset -d . -o ${DATASET_DIR}/chafen_arknights; \
	fi
	if [ ! -d ${DATASET_DIR}/monochrome_danbooru ]; then \
		git clone https://huggingface.co/datasets/deepghs/monochrome_danbooru.git ${DATASET_DIR}/monochrome_danbooru; \
		hfutils download -r deepghs/monochrome_danbooru -t dataset -d . -o ${DATASET_DIR}/monochrome_danbooru; \
	fi
	if [ ! -d ${DATASET_DIR}/images_test_v1 ]; then \
		mkdir -p ${DATASET_DIR}/images_test_v1 && \
		curl -L -o ${DATASET_DIR}/images_test_v1/images_test_v1.tar.xz https://huggingface.co/datasets/deepghs/character_similarity/resolve/main/images_test_v1.tar.xz && \
		cd ${DATASET_DIR}/images_test_v1 && tar -xvf images_test_v1.tar.xz && rm -rf *.tar.xz; \
		hfutils download -r deepghs/character_similarity -t dataset -a images_test_v1.tar.xz -o ${DATASET_DIR}/images_test_v1; \
	fi
	if [ ! -d ${DATASET_DIR}/unsplash_1000 ]; then \
		hfutils download -r deepghs/realutils_unittest -a unsplash_1000.zip -o ${DATASET_DIR}/unsplash_1000; \
	fi
+1 −0
Original line number Diff line number Diff line
@@ -4,4 +4,5 @@ Overview:
"""
from .classify import *
from .enhance import *
from .siglip import *
from .yolo import *
+102 −16
Original line number Diff line number Diff line
@@ -10,7 +10,14 @@ from tokenizers import Tokenizer

from ..data import MultiImagesTyping, load_images
from ..preprocess import create_pillow_transforms
from ..utils import open_onnx_model, vreplace, sigmoid
from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache

__all__ = [
    'SigLIPModel',
    'siglip_image_encode',
    'siglip_text_encode',
    'siglip_predict',
]


class SigLIPModel:
@@ -90,6 +97,7 @@ class SigLIPModel:
        """
        with self._model_lock:
            if model_name not in self._image_encoders:
                self._check_model_name(model_name)
                self._image_encoders[model_name] = open_onnx_model(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
@@ -109,6 +117,7 @@ class SigLIPModel:
        """
        with self._model_lock:
            if model_name not in self._image_preprocessors:
                self._check_model_name(model_name)
                with open(hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
@@ -129,6 +138,7 @@ class SigLIPModel:
        """
        with self._model_lock:
            if model_name not in self._text_encoders:
                self._check_model_name(model_name)
                self._text_encoders[model_name] = open_onnx_model(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
@@ -148,6 +158,7 @@ class SigLIPModel:
        """
        with self._model_lock:
            if model_name not in self._text_tokenizers:
                self._check_model_name(model_name)
                self._text_tokenizers[model_name] = Tokenizer.from_file(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
@@ -167,6 +178,7 @@ class SigLIPModel:
        """
        with self._model_lock:
            if model_name not in self._logit_scales:
                self._check_model_name(model_name)
                with open(hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
@@ -177,7 +189,19 @@ class SigLIPModel:

        return self._logit_scales[model_name]

    def get_siglip_image_embedding(self, images: MultiImagesTyping, model_name: str, fmt: Any = 'embeddings'):
    def _get_siglip_image_embedding(self, images: MultiImagesTyping, model_name: str, fmt: Any = 'embeddings'):
        """
        Run the image encoder of ``model_name`` over ``images``.

        :param images: Image input accepted by ``load_images``
        :type images: MultiImagesTyping
        :param model_name: Name of the SigLIP model variant to use
        :type model_name: str
        :param fmt: Output selector resolved through ``vreplace`` against the
            keys ``'encodings'`` and ``'embeddings'``

        :return: The model outputs arranged according to ``fmt``
        """
        transform = self._open_image_preprocessor(model_name)
        encoder = self._open_image_encoder(model_name)

        # Every image is loaded as RGB on a white background before preprocessing,
        # then the per-image arrays are stacked into one batch.
        pixel_values = np.stack([
            transform(item)
            for item in load_images(images, mode='RGB', force_background='white')
        ])
        encodings, embeddings = encoder.run(
            ['encodings', 'embeddings'],
            {'pixel_values': pixel_values},
        )
        outputs = {
            'encodings': encodings,
            'embeddings': embeddings,
        }
        return vreplace(fmt, outputs)

    def image_encode(self, images: MultiImagesTyping, model_name: str, fmt: Any = 'embeddings'):
        """
        Generate embeddings for input images using the SigLIP model.
    
@@ -189,18 +213,13 @@ class SigLIPModel:
    
        :return: Image embeddings or encodings based on fmt parameter
        """
        preprocessor = self._open_image_preprocessor(model_name)
        model = self._open_image_encoder(model_name)

        images = load_images(images, mode='RGB', force_background='white')
        input_ = np.stack([preprocessor(image) for image in images])
        encodings, embeddings = model.run(['encodings', 'embeddings'], {'pixel_values': input_})
        return vreplace(fmt, {
            'encodings': encodings,
            'embeddings': embeddings,
        })
        return self._get_siglip_image_embedding(
            images=images,
            model_name=model_name,
            fmt=fmt,
        )

    def get_siglip_text_embedding(self, texts: Union[str, List[str]], model_name: str, fmt: Any = 'embeddings'):
    def _get_siglip_text_embedding(self, texts: Union[str, List[str]], model_name: str, fmt: Any = 'embeddings'):
        """
        Generate embeddings for input texts using the SigLIP model.
    
@@ -227,7 +246,25 @@ class SigLIPModel:
            'embeddings': embeddings,
        })

    def classify_with_siglip(
    def text_encode(self, texts: Union[str, List[str]], model_name: str, fmt: Any = 'embeddings'):
        """
        Generate embeddings for input texts using the SigLIP model.

        :param texts: Input text or list of texts
        :type texts: Union[str, List[str]]
        :param model_name: Name of the SigLIP model variant to use
        :type model_name: str
        :param fmt: Output format, either 'encodings' or 'embeddings'

        :return: Text embeddings or encodings based on fmt parameter
        """
        return self._get_siglip_text_embedding(
            texts=texts,
            model_name=model_name,
            fmt=fmt,
        )

    def predict(
            self,
            images: Union[MultiImagesTyping, np.ndarray],
            texts: Union[List[str], str, np.ndarray],
@@ -250,7 +287,7 @@ class SigLIPModel:
        extra_values = {}
        if not isinstance(images, np.ndarray):
            image_embeddings, image_encodings = \
                self.get_siglip_image_embedding(images, model_name=model_name, fmt=('embeddings', 'encodings'))
                self._get_siglip_image_embedding(images, model_name=model_name, fmt=('embeddings', 'encodings'))
            extra_values['image_embeddings'] = image_embeddings
            extra_values['image_encodings'] = image_encodings
            images = image_embeddings
@@ -258,7 +295,7 @@ class SigLIPModel:

        if not isinstance(texts, np.ndarray):
            text_embeddings, text_encodings = \
                self.get_siglip_text_embedding(texts, model_name=model_name, fmt=('embeddings', 'encodings'))
                self._get_siglip_text_embedding(texts, model_name=model_name, fmt=('embeddings', 'encodings'))
            extra_values['text_embeddings'] = text_embeddings
            extra_values['text_encodings'] = text_encodings
            texts = text_embeddings
@@ -275,3 +312,52 @@ class SigLIPModel:
            'predictions': predictions,
            **extra_values,
        })

    def clear(self):
        self._image_encoders.clear()
        self._image_preprocessors.clear()
        self._text_encoders.clear()
        self._text_tokenizers.clear()
        self._logit_scales.clear()


@ts_lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> SigLIPModel:
    """
    Build the :class:`SigLIPModel` for a repository; results are cached by
    ``ts_lru_cache``.

    NOTE(review): if ``ts_lru_cache`` keys calls the way ``functools.lru_cache``
    does, ``f(repo_id)`` and ``f(repo_id, hf_token=None)`` are cached as separate
    entries and yield different instances -- confirm, and keep call sites
    consistent.

    :param repo_id: Repository id handed to :class:`SigLIPModel`
    :type repo_id: str
    :param hf_token: Optional token forwarded to :class:`SigLIPModel`
    :type hf_token: Optional[str]

    :return: The model wrapper for ``repo_id``
    :rtype: SigLIPModel
    """
    model = SigLIPModel(repo_id, hf_token=hf_token)
    return model


def siglip_image_encode(images: MultiImagesTyping, repo_id: str, model_name: str,
                        fmt: Any = 'embeddings', hf_token: Optional[str] = None):
    """
    Encode images with a SigLIP model taken from ``repo_id``.

    :param images: Image input forwarded to :meth:`SigLIPModel.image_encode`
    :type images: MultiImagesTyping
    :param repo_id: Repository id of the model collection
    :type repo_id: str
    :param model_name: Name of the SigLIP model variant to use
    :type model_name: str
    :param fmt: Output format, forwarded unchanged (defaults to ``'embeddings'``)
    :param hf_token: Optional token used when opening the repository
    :type hf_token: Optional[str]

    :return: Whatever :meth:`SigLIPModel.image_encode` returns for ``fmt``
    """
    # hf_token is deliberately passed by keyword: the cached lookup is shared
    # with the other siglip_* helpers, which use the same call pattern.
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).image_encode(
        images=images,
        model_name=model_name,
        fmt=fmt,
    )


def siglip_text_encode(texts: Union[str, List[str]], repo_id: str, model_name: str,
                       fmt: Any = 'embeddings', hf_token: Optional[str] = None):
    """
    Encode text with a SigLIP model taken from ``repo_id``.

    :param texts: A single string or a list of strings
    :type texts: Union[str, List[str]]
    :param repo_id: Repository id of the model collection
    :type repo_id: str
    :param model_name: Name of the SigLIP model variant to use
    :type model_name: str
    :param fmt: Output format, forwarded unchanged (defaults to ``'embeddings'``)
    :param hf_token: Optional token used when opening the repository
    :type hf_token: Optional[str]

    :return: Whatever :meth:`SigLIPModel.text_encode` returns for ``fmt``
    """
    # hf_token is deliberately passed by keyword: the cached lookup is shared
    # with the other siglip_* helpers, which use the same call pattern.
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).text_encode(
        texts=texts,
        model_name=model_name,
        fmt=fmt,
    )


def siglip_predict(
        images: Union[MultiImagesTyping, np.ndarray],
        texts: Union[List[str], str, np.ndarray],
        repo_id: str,
        model_name: str,
        fmt: Any = 'predictions',
        hf_token: Optional[str] = None,
):
    """
    Score images against texts with a SigLIP model taken from ``repo_id``.

    :param images: Images to encode, or an ``np.ndarray`` that
        :meth:`SigLIPModel.predict` uses directly instead of encoding
    :type images: Union[MultiImagesTyping, np.ndarray]
    :param texts: Text(s) to encode, or an ``np.ndarray`` that
        :meth:`SigLIPModel.predict` uses directly instead of encoding
    :type texts: Union[List[str], str, np.ndarray]
    :param repo_id: Repository id of the model collection
    :type repo_id: str
    :param model_name: Name of the SigLIP model variant to use
    :type model_name: str
    :param fmt: Output format, forwarded unchanged (defaults to ``'predictions'``)
    :param hf_token: Optional token used when opening the repository
    :type hf_token: Optional[str]

    :return: Whatever :meth:`SigLIPModel.predict` returns for ``fmt``
    """
    # hf_token is deliberately passed by keyword: the cached lookup is shared
    # with the other siglip_* helpers, which use the same call pattern.
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict(
        images=images,
        texts=texts,
        model_name=model_name,
        fmt=fmt,
    )
+86 −0
Original line number Diff line number Diff line
import re

import numpy as np
import pytest

from imgutils.generic import siglip_image_encode, siglip_text_encode, siglip_predict
from imgutils.generic.siglip import _open_models_for_repo_id
from test.testings import get_testfile


@pytest.fixture(scope='module')
def siglip_repo_id():
    """Repository id passed as ``repo_id`` to the SigLIP helpers under test."""
    return 'deepghs/siglip_onnx'


@pytest.fixture(scope='module')
def siglip_model_name():
    """Model variant name passed as ``model_name`` to the SigLIP helpers under test."""
    return 'google/siglip-base-patch16-256-multilingual'


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run(siglip_repo_id):
    """
    Module-level teardown that clears the models loaded by the tests in this module.

    :param siglip_repo_id: Repository id whose cached model wrapper is cleared
    """
    try:
        yield
    finally:
        # The siglip_* helpers fetch the cached wrapper as
        # ``_open_models_for_repo_id(repo_id, hf_token=hf_token)``. An lru_cache-style
        # cache keys on the exact argument pattern, so omitting the keyword here
        # would create (and clear) a brand-new, empty wrapper and leave the
        # loaded models in memory. Use the same call pattern as the helpers.
        _open_models_for_repo_id(siglip_repo_id, hf_token=None).clear()


@pytest.mark.unittest
class TestGenericSiglip:
    """Checks the module-level SigLIP helpers against stored reference values."""

    @pytest.mark.parametrize('name', [
        'unsplash_sZzmhn2xjQY',
        'unsplash_S-8ntPEsSwo',
        'unsplash_tB4-ftQ4zyI',
        'unsplash_l6KamCXeB4U',
        'unsplash__9dAwWA4LD8',
        'unsplash_LlsAieNJE70',
        'unsplash_HWIOLU7_O6w',
        'unsplash_1AAa78W_Ezc',
        'unsplash_0TPmrjTXjSs',
        'unsplash_0yAVtZiYkJY',
    ])
    def test_siglip_image_encode(self, name, siglip_repo_id, siglip_model_name):
        """Image embedding of ``<name>.jpg`` must match the stored ``<name>.npy``."""
        image_file = get_testfile('dataset', 'unsplash_1000', f'{name}.jpg')
        reference_file = get_testfile('siglip', 'unsplash_1000', f'{name}.npy')
        actual = siglip_image_encode(
            image_file,
            repo_id=siglip_repo_id,
            model_name=siglip_model_name,
        )
        reference = np.load(reference_file)
        np.testing.assert_allclose(actual, reference, rtol=1e-03, atol=1e-05)

    @pytest.mark.parametrize('text', [
        "a red car parked on the street",
        "beautiful sunset over mountain landscape",
        "two cats playing with yarn",
        "fresh fruits in a wooden bowl",
        "person reading book under tree",
        "colorful hot air balloon in blue sky",
        "children playing soccer in the park",
        "rustic cabin surrounded by pine trees",
        "waves crashing on sandy beach",
        "chef cooking in modern kitchen",
    ])
    def test_siglip_text_encode(self, text, siglip_repo_id, siglip_model_name):
        """Text embedding must match the ``.npy`` file named after the slugified text."""
        # Reference files are named by collapsing every run of non-word
        # characters / underscores into a single underscore.
        slug = re.sub(r'[\W_]+', '_', text).strip('_')
        reference_file = get_testfile('siglip', 'text', slug + '.npy')
        actual = siglip_text_encode(
            text,
            repo_id=siglip_repo_id,
            model_name=siglip_model_name,
        )
        reference = np.load(reference_file)
        np.testing.assert_allclose(actual, reference, rtol=1e-03, atol=1e-05)

    def test_siglip_predict(self, siglip_repo_id, siglip_model_name):
        """Prediction matrix for two images and four captions must match the stored values."""
        image_files = [
            get_testfile('clip_cats.jpg'),
            get_testfile('idolsankaku', '3.jpg'),
        ]
        captions = [
            'a photo of a cat',
            'a photo of 2 cats',
            'a photo of 2 dogs',
            'a photo of a woman',
        ]
        actual = siglip_predict(
            images=image_files,
            texts=captions,
            repo_id=siglip_repo_id,
            model_name=siglip_model_name,
        )
        # One row per image, one column per caption.
        reference = np.array(
            [[0.0013782851165160537, 0.27010253071784973, 9.751768811838701e-05, 3.6702780814579228e-09],
             [1.2790776438009743e-08, 4.396981001519862e-09, 3.2838454178119036e-10, 1.0559210750216153e-06]])
        np.testing.assert_allclose(actual, reference, atol=3e-4)
+2.13 KiB

File added.

No diff preview for this file type.

Loading