Commit 748d1d34 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add bit preprocessor

parent 44e38ed7
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ Supported Processors:
    .. include:: transformers_supported.demo.py.txt
"""
from .base import register_creators_for_transformers, NotProcessorTypeError, create_transforms_from_transformers
from .bit import create_bit_transforms, create_transforms_from_bit_processor
from .clip import create_clip_transforms, create_transforms_from_clip_processor
from .convnext import create_convnext_transforms, create_transforms_from_convnext_processor
from .siglip import create_siglip_transforms, create_transforms_from_siglip_processor
+87 −0
Original line number Diff line number Diff line
from PIL import Image

from .base import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, _DEFAULT, register_creators_for_transformers, _check_transformers, \
    NotProcessorTypeError
from ..pillow import PillowConvertRGB, PillowResize, PillowCenterCrop, PillowToTensor, PillowNormalize, PillowCompose, \
    PillowRescale

_DEFAULT_SIZE = {"shortest_edge": 224}
_DEFAULT_CROP_SIZE = {"height": 224, "width": 224}


def create_bit_transforms(
        do_resize: bool = True,
        size=_DEFAULT,
        resample=Image.BICUBIC,
        do_center_crop: bool = True,
        crop_size=_DEFAULT,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean=_DEFAULT,
        image_std=_DEFAULT,
        do_convert_rgb: bool = True,
):
    # Set default values
    size = size if size is not _DEFAULT else _DEFAULT_SIZE
    crop_size = crop_size if crop_size is not _DEFAULT else _DEFAULT_CROP_SIZE
    image_mean = image_mean if image_mean is not _DEFAULT else OPENAI_CLIP_MEAN
    image_std = image_std if image_std is not _DEFAULT else OPENAI_CLIP_STD

    transform_list = []

    # Convert to RGB
    if do_convert_rgb:
        transform_list.append(PillowConvertRGB())

    # Resize
    if do_resize:
        if "shortest_edge" in size:
            transform_list.append(PillowResize(size["shortest_edge"], interpolation=resample))
        elif "height" in size and "width" in size:
            transform_list.append(PillowResize((size["height"], size["width"]), interpolation=resample))
        else:
            raise ValueError(f'Unknown size configuration - {size!r}.')  # pragma: no cover

    # Center crop
    if do_center_crop:
        transform_list.append(PillowCenterCrop((crop_size["height"], crop_size["width"])))

    # Convert to tensor (implicitly scales to [0,1])
    transform_list.append(PillowToTensor())

    # Rescale
    if do_rescale and rescale_factor != 1 / 255:
        transform_list.append(PillowRescale(rescale_factor * 255))

    # Normalize
    if do_normalize:
        transform_list.append(PillowNormalize(mean=image_mean, std=image_std))

    return PillowCompose(transform_list)


@register_creators_for_transformers()
def create_transforms_from_bit_processor(processor):
    _check_transformers()
    from transformers import BitImageProcessor

    if isinstance(processor, BitImageProcessor):
        pass
    else:
        raise NotProcessorTypeError(f'Unknown Bit processor - {processor!r}.')
    processor: BitImageProcessor

    return create_bit_transforms(
        do_resize=processor.do_resize,
        size=processor.size,
        resample=processor.resample,
        do_center_crop=processor.do_center_crop,
        crop_size=processor.crop_size,
        do_rescale=processor.do_rescale,
        rescale_factor=processor.rescale_factor,
        do_normalize=processor.do_normalize,
        image_mean=processor.image_mean,
        image_std=processor.image_std,
        do_convert_rgb=processor.do_convert_rgb,
    )
+77 −0
Original line number Diff line number Diff line
from unittest import skipUnless

import numpy as np
import pytest
from hbutils.testing import tmatrix

from imgutils.data import load_image
from imgutils.preprocess.transformers import create_transforms_from_transformers
from test.testings import get_testfile

try:
    import transformers
except (ImportError, ModuleNotFoundError):
    _HAS_TRANSFORMERS = False
else:
    _HAS_TRANSFORMERS = True


@pytest.mark.unittest
class TestPreprocessTransformersBit:
    @skipUnless(_HAS_TRANSFORMERS, 'Transformers required.')
    @pytest.mark.parametrize(*tmatrix({
        'repo_id': [
            'facebook/dinov2-small-imagenet1k-1-layer',
            'robertsw/aesthetics_v2',
            'facebook/hiera-huge-224-mae-hf',
            'Kalinga/dinov2-base-finetuned-oxford',
            'facebook/dinov2-with-registers-giant',
            'facebook/hiera-base-224-in1k-hf',
            'zkatona/dinov2-base-finetuned-oxford',
            'microsoft/focalnet-small',
            'atuo/vit-base-patch16-224-in21k-finetuned-crop-classification',
            'facebook/hiera-tiny-224-in1k-hf',
            'facebook/hiera-tiny-224-mae-hf',
            'microsoft/rad-dino',
            'facebook/dinov2-with-registers-base',
            'NamLe12/vit-base-beans',
            'suncy13/Foot',

            'facebook/dinov2-base',
            'facebook/dinov2-large',
            'facebook/dinov2-giant',
            'facebook/dinov2-small',
            'facebook/dinov2-base-imagenet1k-1-layer',
            'facebook/dinov2-with-registers-giant',
            'facebook/dinov2-with-registers-small',
            'facebook/dinov2-giant-imagenet1k-1-layer',
            'facebook/dinov2-with-registers-large',
            'facebook/dinov2-with-registers-base',
            'facebook/dinov2-large-imagenet1k-1-layer',
            'facebook/dinov2-small-imagenet1k-1-layer',
            'facebook/dinov2-with-registers-giant-imagenet1k-1-layer',
            'facebook/dinov2-with-registers-base-imagenet1k-1-layer',
            'facebook/dinov2-with-registers-large-imagenet1k-1-layer',
            'facebook/dinov2-with-registers-small-imagenet1k-1-layer',
        ],
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
            'nude_girl.png',
            'dori_640.png',
            'nian_640.png',
        ]
    }))
    def test_bit_image_preprocess_align(self, src_image, repo_id):
        from transformers import AutoImageProcessor
        image = load_image(get_testfile(src_image), mode='RGB', force_background='white')
        processor = AutoImageProcessor.from_pretrained(repo_id)

        trans = create_transforms_from_transformers(processor)

        expected_output = processor.preprocess(image)['pixel_values'][0]
        output = trans(image)
        np.testing.assert_array_almost_equal(
            output,
            expected_output,
        )