Commit a0fb2e7f authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add siglip

parent ee904dbf
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -71,3 +71,18 @@ create_transforms_from_vit_processor



create_siglip_transforms
--------------------------------------------------------------------

.. autofunction:: create_siglip_transforms



create_transforms_from_siglip_processor
--------------------------------------------------------------------

.. autofunction:: create_transforms_from_siglip_processor



+1 −0
Original line number Diff line number Diff line
@@ -9,4 +9,5 @@ Supported Processors:
from .base import register_creators_for_transformers, NotProcessorTypeError, create_transforms_from_transformers
from .clip import create_clip_transforms, create_transforms_from_clip_processor
from .convnext import create_convnext_transforms, create_transforms_from_convnext_processor
from .siglip import create_siglip_transforms, create_transforms_from_siglip_processor
from .vit import create_vit_transforms, create_transforms_from_vit_processor
+71 −0
Original line number Diff line number Diff line
from PIL import Image

from .base import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, _DEFAULT, _check_transformers, NotProcessorTypeError, \
    register_creators_for_transformers
from ..pillow import PillowCompose, PillowNormalize, PillowRescale, PillowToTensor, PillowResize, PillowConvertRGB

_DEFAULT_SIZE = {"height": 224, "width": 224}


def create_siglip_transforms(
        do_resize: bool = True,
        size=_DEFAULT,
        resample: int = Image.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean=_DEFAULT,
        image_std=_DEFAULT,
        do_convert_rgb: bool = True,
):
    # Set default values
    size = size if size is not _DEFAULT else _DEFAULT_SIZE
    image_mean = image_mean if image_mean is not _DEFAULT else IMAGENET_STANDARD_MEAN
    image_std = image_std if image_std is not _DEFAULT else IMAGENET_STANDARD_STD

    transforms_list = []

    # Convert to RGB
    if do_convert_rgb:
        transforms_list.append(PillowConvertRGB())

    # Resize
    if do_resize:
        transforms_list.append(PillowResize((size["height"], size["width"]), interpolation=resample))

    # Convert to tensor (implicitly rescales to 0-1)
    transforms_list.append(PillowToTensor())

    # Rescale if needed (only if different from 1/255)
    if do_rescale and rescale_factor != 1 / 255:
        transforms_list.append(PillowRescale(rescale_factor * 255))

    # Normalize
    if do_normalize:
        transforms_list.append(PillowNormalize(mean=image_mean, std=image_std))

    return PillowCompose(transforms_list)


@register_creators_for_transformers()
def create_transforms_from_siglip_processor(processor):
    _check_transformers()
    from transformers import SiglipImageProcessor

    if isinstance(processor, SiglipImageProcessor):
        pass
    else:
        raise NotProcessorTypeError(f'Unknown Siglip processor - {processor!r}.')
    processor: SiglipImageProcessor

    return create_siglip_transforms(
        do_resize=processor.do_resize,
        size=processor.size,
        resample=processor.resample,
        do_rescale=processor.do_rescale,
        rescale_factor=processor.rescale_factor,
        do_normalize=processor.do_normalize,
        image_mean=processor.image_mean,
        image_std=processor.image_std,
        do_convert_rgb=processor.do_convert_rgb,
    )
+60 −0
Original line number Diff line number Diff line
from unittest import skipUnless

import numpy as np
import pytest
from hbutils.testing import tmatrix

from imgutils.data import load_image
from imgutils.preprocess.transformers import create_transforms_from_transformers
from test.testings import get_testfile

try:
    import transformers
except (ImportError, ModuleNotFoundError):
    _HAS_TRANSFORMERS = False
else:
    _HAS_TRANSFORMERS = True


@pytest.mark.unittest
class TestPreprocessTransformersSiglip:
    @skipUnless(_HAS_TRANSFORMERS, 'Transformers required.')
    @pytest.mark.parametrize(*tmatrix({
        'repo_id': [
            'Marqo/marqo-ecommerce-embeddings-B',
            'ucsahin/TraVisionLM-DPO',
            'google/siglip-base-patch16-384',
            'google/siglip-base-patch16-512',
            'llava-hf/llava-interleave-qwen-0.5b-hf',
            'zhumj34/Mipha-3B',
            'google/siglip-so400m-patch14-384',
            'lmms-lab/llava-onevision-qwen2-72b-ov-sft',
            'p1atdev/siglip-tagger-test-3',
            'gokaygokay/paligemma-rich-captions',
            'lmms-lab/llava-onevision-qwen2-0.5b-ov',
            'gokaygokay/sd3-long-captioner-v2',
            'OpenFace-CQUPT/Human_LLaVA',
            'ucsahin/TraVisionLM-base',
            'mlx-community/paligemma-3b-mix-448-8bit',
        ],
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
            'nude_girl.png',
            'dori_640.png',
            'nian_640.png',
        ]
    }))
    def test_siglip_image_preprocess_align(self, src_image, repo_id):
        from transformers import AutoImageProcessor
        image = load_image(get_testfile(src_image), mode='RGB', force_background='white')
        processor = AutoImageProcessor.from_pretrained(repo_id)

        trans = create_transforms_from_transformers(processor)

        expected_output = processor.preprocess(image)['pixel_values'][0]
        output = trans(image)
        np.testing.assert_array_almost_equal(
            output,
            expected_output,
        )
+1 −1
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ class TestPreprocessTransformersViT:
            'nian_640.png',
        ]
    }))
    def test_convnext_image_preprocess_align(self, src_image, repo_id):
    def test_vit_image_preprocess_align(self, src_image, repo_id):
        from transformers import AutoImageProcessor
        image = load_image(get_testfile(src_image), mode='RGB', force_background='white')
        processor = AutoImageProcessor.from_pretrained(repo_id)