Commit de4f6b8f authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add 2 more pillow transforms

parent f411eb03
Loading
Loading
Loading
Loading
+34 −2
Original line number Diff line number Diff line
@@ -656,7 +656,7 @@ class PillowConvertRGB:
    def __init__(self, force_background: Optional[str] = 'white'):
        self.force_background = force_background

    def forward(self, pic):
    def __call__(self, pic):
        if not isinstance(pic, Image.Image):
            raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))
        return load_image(pic, mode='RGB', force_background=self.force_background)
@@ -665,9 +665,25 @@ class PillowConvertRGB:
        return f'{self.__class__.__name__}(force_background={self.force_background!r})'


@register_pillow_transform('convert_rgb')
def _create_convert_rgb(force_background: Optional[str] = 'white'):
    return PillowConvertRGB(force_background=force_background)


@register_pillow_parse('convert_rgb')
def _parse_convert_rgb(obj):
    if not isinstance(obj, PillowConvertRGB):
        raise NotParseTarget

    obj: PillowConvertRGB
    return {
        'force_background': obj.force_background,
    }


class PillowRescale:
    def __init__(self, rescale_factor: float = 1 / 255):
        self.rescale_factor = rescale_factor
        self.rescale_factor = np.float32(rescale_factor)

    def __call__(self, array):
        if not isinstance(array, np.ndarray):
@@ -678,6 +694,22 @@ class PillowRescale:
        return f'{self.__class__.__name__}(rescale_factor={self.rescale_factor!r})'


@register_pillow_transform('rescale')
def _create_rescale(rescale_factor: float = 1 / 255):
    return PillowRescale(rescale_factor=rescale_factor)


@register_pillow_parse('rescale')
def _parse_rescale(obj):
    if not isinstance(obj, PillowRescale):
        raise NotParseTarget

    obj: PillowRescale
    return {
        'rescale_factor': obj.rescale_factor.item(),
    }


class PillowCompose:
    """
    Composes several transforms together into a single transform.
+144 −1
Original line number Diff line number Diff line
@@ -5,8 +5,10 @@ import pytest
from PIL import Image
from hbutils.testing import tmatrix

from imgutils.data import load_image
from imgutils.preprocess.pillow import PillowResize, _get_pillow_resample, PillowCenterCrop, PillowToTensor, \
    PillowMaybeToTensor, PillowNormalize, create_pillow_transforms, parse_pillow_transforms, PillowCompose
    PillowMaybeToTensor, PillowNormalize, create_pillow_transforms, parse_pillow_transforms, PillowCompose, \
    PillowConvertRGB, PillowRescale
from imgutils.preprocess.torchvision import _get_interpolation_mode
from test.testings import get_testfile

@@ -867,3 +869,144 @@ class TestPreprocessPillow:
            _ = parse_pillow_transforms(None)
        with pytest.raises(TypeError):
            _ = parse_pillow_transforms(23344)

    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
            'dori_640.png',
            'nian_640.png',
        ],
        'bg_color': [
            'white', 'green', 'blue', 'red', 'black',
        ]
    }))
    def test_convert_rgb(self, src_image, bg_color, image_diff):
        image = Image.open(get_testfile(src_image))
        pconvertrgb = PillowConvertRGB(force_background=bg_color)
        dst_image = pconvertrgb(image)
        assert dst_image.mode == 'RGB'

        assert image_diff(
            dst_image,
            load_image(image, force_background=bg_color, mode='RGB'),
            throw_exception=False
        ) < 1e-3

    def test_convert_rgb_invalid_input(self):
        pconvertrgb = PillowConvertRGB()
        with pytest.raises(TypeError):
            pconvertrgb(np.random.randn(1, 3, 384, 384))

    @pytest.mark.parametrize(['color', 'repr_text'], [
        (None, "PillowConvertRGB(force_background='white')"),
        ('white', "PillowConvertRGB(force_background='white')"),
        ('black', "PillowConvertRGB(force_background='black')"),
        ('red', "PillowConvertRGB(force_background='red')"),
        ('green', "PillowConvertRGB(force_background='green')"),
        ('blue', "PillowConvertRGB(force_background='blue')"),
    ])
    def test_convert_rgb_repr(self, color, repr_text):
        pconvertrgb = PillowConvertRGB() if color is None else PillowConvertRGB(color)
        assert repr(pconvertrgb) == repr_text

    @pytest.mark.parametrize(['color', 'json_data'], [
        ('white', {'type': 'convert_rgb'}),
        ('white', {'type': 'convert_rgb', 'force_background': 'white'}),
        ('black', {'type': 'convert_rgb', 'force_background': 'black'}),
        ('red', {'type': 'convert_rgb', 'force_background': 'red'}),
        ('green', {'type': 'convert_rgb', 'force_background': 'green'}),
        ('blue', {'type': 'convert_rgb', 'force_background': 'blue'}),
    ])
    def test_create_convert_rgb(self, color, json_data):
        pconvertrgb = create_pillow_transforms(json_data)
        assert isinstance(pconvertrgb, PillowConvertRGB)
        assert pconvertrgb.force_background == color

    @pytest.mark.parametrize(['color', 'json_data'], [
        (None, {'type': 'convert_rgb', 'force_background': 'white'}),
        ('white', {'type': 'convert_rgb', 'force_background': 'white'}),
        ('black', {'type': 'convert_rgb', 'force_background': 'black'}),
        ('red', {'type': 'convert_rgb', 'force_background': 'red'}),
        ('green', {'type': 'convert_rgb', 'force_background': 'green'}),
        ('blue', {'type': 'convert_rgb', 'force_background': 'blue'}),
    ])
    def test_parse_convert_rgb(self, color, json_data):
        pconvertrgb = PillowConvertRGB() if color is None else PillowConvertRGB(color)
        assert parse_pillow_transforms(pconvertrgb) == json_data

    @pytest.mark.parametrize(*tmatrix({
        'seed': list(range(10)),
        'rescale_factor': [1 / 255, 1 / 254, 1 / 256, 1 / 127, 0.5, 0.1, 2, 10, 255, 254, 256],
    }))
    def test_rescale(self, seed, rescale_factor):
        np.random.seed(seed)
        arr = np.random.randn(3, 384, 384)
        prescale = PillowRescale(rescale_factor=rescale_factor)
        np.testing.assert_array_almost_equal(arr * rescale_factor, prescale(arr))

    @pytest.mark.parametrize(['rescale_factor', 'repr_text'], [
        (1 / 255, 'PillowRescale(rescale_factor=0.003921569)'),
        (1 / 254, 'PillowRescale(rescale_factor=0.003937008)'),
        (1 / 256, 'PillowRescale(rescale_factor=0.00390625)'),
        (1 / 127, 'PillowRescale(rescale_factor=0.007874016)'),
        (1 / 2, 'PillowRescale(rescale_factor=0.5)'),
        (1 / 10, 'PillowRescale(rescale_factor=0.1)'),
        (2, 'PillowRescale(rescale_factor=2.0)'),
        (10, 'PillowRescale(rescale_factor=10.0)'),
        (255, 'PillowRescale(rescale_factor=255.0)'),
        (254, 'PillowRescale(rescale_factor=254.0)'),
        (256, 'PillowRescale(rescale_factor=256.0)'),
    ])
    def test_rescale_repr(self, rescale_factor, repr_text):
        prescale = PillowRescale(rescale_factor)
        assert repr(prescale) == repr_text

    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
            'dori_640.png',
            'nian_640.png',
        ],
    }))
    def test_rescale_input_invalid(self, src_image):
        prescale = PillowRescale(1 / 255)
        image = Image.open(get_testfile(src_image))
        with pytest.raises(TypeError):
            _ = prescale(image)

    @pytest.mark.parametrize(['rescale_factor', 'json_data'], [
        (1 / 255, {'type': 'rescale', 'rescale_factor': 0.003921568859368563}),
        (1 / 254, {'type': 'rescale', 'rescale_factor': 0.003937007859349251}),
        (1 / 256, {'type': 'rescale', 'rescale_factor': 0.00390625}),
        (1 / 127, {'type': 'rescale', 'rescale_factor': 0.007874015718698502}),
        (1 / 2, {'type': 'rescale', 'rescale_factor': 0.5}),
        (1 / 10, {'type': 'rescale', 'rescale_factor': 0.10000000149011612}),
        (2, {'type': 'rescale', 'rescale_factor': 2.0}),
        (10, {'type': 'rescale', 'rescale_factor': 10.0}),
        (255, {'type': 'rescale', 'rescale_factor': 255.0}),
        (254, {'type': 'rescale', 'rescale_factor': 254.0}),
        (256, {'type': 'rescale', 'rescale_factor': 256.0}),
    ])
    def test_rescale_parse(self, rescale_factor, json_data):
        prescale = PillowRescale(rescale_factor)
        assert parse_pillow_transforms(prescale) == pytest.approx(json_data, abs=1e-5)

    @pytest.mark.parametrize(['rescale_factor', 'json_data'], [
        (1 / 255, {'type': 'rescale', 'rescale_factor': 0.003921568859368563}),
        (1 / 254, {'type': 'rescale', 'rescale_factor': 0.003937007859349251}),
        (1 / 256, {'type': 'rescale', 'rescale_factor': 0.00390625}),
        (1 / 127, {'type': 'rescale', 'rescale_factor': 0.007874015718698502}),
        (1 / 2, {'type': 'rescale', 'rescale_factor': 0.5}),
        (1 / 10, {'type': 'rescale', 'rescale_factor': 0.10000000149011612}),
        (2, {'type': 'rescale', 'rescale_factor': 2.0}),
        (10, {'type': 'rescale', 'rescale_factor': 10.0}),
        (255, {'type': 'rescale', 'rescale_factor': 255.0}),
        (254, {'type': 'rescale', 'rescale_factor': 254.0}),
        (256, {'type': 'rescale', 'rescale_factor': 256.0}),
    ])
    def test_rescale_create(self, rescale_factor, json_data):
        prescale = create_pillow_transforms(json_data)
        assert isinstance(prescale, PillowRescale)
        assert prescale.rescale_factor == pytest.approx(rescale_factor, abs=1e-5)
+388 KiB
Loading image diff...
+369 KiB
Loading image diff...