Commit 7ff62321 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add support for vit

parent 669ddcd8
Loading
Loading
Loading
Loading
+14 −0
Original line number Diff line number Diff line
@@ -57,3 +57,17 @@ create_transforms_from_convnext_processor



create_vit_transforms
--------------------------------------------------------------------

.. autofunction:: create_vit_transforms



create_transforms_from_vit_processor
--------------------------------------------------------------------

.. autofunction:: create_transforms_from_vit_processor


+1 −0
Original line number Diff line number Diff line
@@ -9,3 +9,4 @@ Supported Processors:
from .base import register_creators_for_transformers, NotProcessorTypeError, create_transforms_from_transformers
from .clip import create_clip_transforms, create_transforms_from_clip_processor
from .convnext import create_convnext_transforms, create_transforms_from_convnext_processor
from .vit import create_vit_transforms, create_transforms_from_vit_processor
+1 −1
Original line number Diff line number Diff line
@@ -110,7 +110,7 @@ def create_transforms_from_convnext_processor(processor):
    if isinstance(processor, ConvNextImageProcessor):
        pass
    else:
        raise NotProcessorTypeError(f'Unknown CLIP processor - {processor!r}.')
        raise NotProcessorTypeError(f'Unknown ConvNext processor - {processor!r}.')
    processor: ConvNextImageProcessor

    return create_convnext_transforms(
+71 −0
Original line number Diff line number Diff line
from PIL import Image

from .base import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, _DEFAULT, register_creators_for_transformers, \
    _check_transformers, NotProcessorTypeError
from ..pillow import PillowRescale, PillowResize, PillowToTensor, PillowNormalize, PillowCompose

_DEFAULT_SIZE = {"height": 224, "width": 224}


def create_vit_transforms(
        do_resize: bool = True,
        size=_DEFAULT,
        resample: int = Image.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean=_DEFAULT,
        image_std=_DEFAULT,
):
    # Initialize default values
    size = size if size is not _DEFAULT else _DEFAULT_SIZE
    image_mean = image_mean if image_mean is not _DEFAULT else IMAGENET_DEFAULT_MEAN
    image_std = image_std if image_std is not _DEFAULT else IMAGENET_DEFAULT_STD

    transform_list = []

    # Add resize transform if enabled
    if do_resize:
        transform_list.append(
            PillowResize(
                (size["height"], size["width"]),
                interpolation=resample
            )
        )

    # Convert to tensor (always needed)
    transform_list.append(PillowToTensor())

    # Add rescaling if enabled
    # Note: ToTensor already scales to [0,1], so we only need additional scaling if factor != 1/255
    if do_rescale and rescale_factor != 1 / 255:
        transform_list.append(PillowRescale(rescale_factor * 255))

    # Add normalization if enabled0
    if do_normalize:
        transform_list.append(PillowNormalize(mean=image_mean, std=image_std))

    return PillowCompose(transform_list)


@register_creators_for_transformers()
def create_transforms_from_vit_processor(processor):
    _check_transformers()
    from transformers import ViTImageProcessor

    if isinstance(processor, ViTImageProcessor):
        pass
    else:
        raise NotProcessorTypeError(f'Unknown ViT processor - {processor!r}.')
    processor: ViTImageProcessor

    return create_vit_transforms(
        do_resize=processor.do_resize,
        size=processor.size,
        resample=processor.resample,
        do_rescale=processor.do_rescale,
        rescale_factor=processor.rescale_factor,
        do_normalize=processor.do_normalize,
        image_mean=processor.image_mean,
        image_std=processor.image_std,
    )
+53 −0
Original line number Diff line number Diff line
from unittest import skipUnless

import numpy as np
import pytest
from hbutils.testing import tmatrix

from imgutils.data import load_image
from imgutils.preprocess.transformers import create_transforms_from_transformers
from test.testings import get_testfile

try:
    import transformers
except (ImportError, ModuleNotFoundError):
    _HAS_TRANSFORMERS = False
else:
    _HAS_TRANSFORMERS = True


@pytest.mark.unittest
class TestPreprocessTransformersViT:
    @skipUnless(_HAS_TRANSFORMERS, 'Transformers required.')
    @pytest.mark.parametrize(*tmatrix({
        'repo_id': [
            "Falconsai/nsfw_image_detection",
            "microsoft/trocr-base-handwritten",
            "microsoft/trocr-base-printed",
            "dima806/facial_emotions_image_detection",
            "rizvandwiki/gender-classification",
            "AdamCodd/vit-base-nsfw-detector",
            "MixTex/ZhEn-Latex-OCR",
            "prithivMLmods/Deep-Fake-Detector-Model",
        ],
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
            'nude_girl.png',
            'dori_640.png',
            'nian_640.png',
        ]
    }))
    def test_convnext_image_preprocess_align(self, src_image, repo_id):
        from transformers import AutoImageProcessor
        image = load_image(get_testfile(src_image), mode='RGB', force_background='white')
        processor = AutoImageProcessor.from_pretrained(repo_id)

        trans = create_transforms_from_transformers(processor)

        expected_output = processor.preprocess(image)['pixel_values'][0]
        output = trans(image)
        np.testing.assert_array_almost_equal(
            output,
            expected_output,
        )