Commit f1b6adbd authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add other transforms

parent 3e265a30
Loading
Loading
Loading
Loading
+60 −1
Original line number Diff line number Diff line
from unittest import skipUnless

import pytest
from PIL import Image
from hbutils.testing import tmatrix

from imgutils.preprocess.torchvision import _get_interpolation_mode
from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms
from test.testings import get_testfile

try:
    import torchvision
@@ -58,3 +61,59 @@ class TestPreprocessPillow:
            _ = _get_interpolation_mode(None)
        with pytest.raises(TypeError):
            _ = _get_interpolation_mode([])

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.')
    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
        ],
        ('mode', 'channels'): [
            ('I', 1), ('I;16', 1), ('F', 1), ('1', 1), ('L', 1), ('P', 1),
            ('LA', 2),
            ('RGB', 3), ('YCbCr', 3),
            ('RGBA', 4), ('CMYK', 4),
        ]
    }))
    def test_maybe_to_tensor(self, src_image, mode, channels):
        image = Image.open(get_testfile(src_image))
        image = image.convert(mode)
        assert image.mode == mode

        ttrans = create_torchvision_transforms({'type': 'maybe_to_tensor'})
        result = ttrans(image)
        assert tuple(result.shape) == (channels, image.height, image.width)

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.')
    def test_maybe_to_tensor_repr(self):
        ttrans = create_torchvision_transforms({'type': 'maybe_to_tensor'})
        assert repr(ttrans) == 'MaybeToTensor()'

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.')
    def test_maybe_to_tensor_np(self):
        import torch
        input_ = torch.randn(3, 384, 384)
        ttrans = create_torchvision_transforms({'type': 'maybe_to_tensor'})
        assert torch.allclose(ttrans(input_), input_)

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.')
    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
        ],
        ('mode', 'channels'): [
            ('I', 1), ('I;16', 1), ('F', 1), ('1', 1), ('L', 1), ('P', 1),
            ('LA', 2),
            ('RGB', 3), ('YCbCr', 3),
            ('RGBA', 4), ('CMYK', 4),
        ]
    }))
    def test_to_tensor(self, src_image, mode, channels):
        image = Image.open(get_testfile(src_image))
        image = image.convert(mode)
        assert image.mode == mode

        ttrans = create_torchvision_transforms({'type': 'to_tensor'})
        result = ttrans(image)
        assert tuple(result.shape) == (channels, image.height, image.width)