Commit ecc6a7b2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest for center crop

parent 0489e4df
Loading
Loading
Loading
Loading
+52 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import pytest
from PIL import Image
from hbutils.testing import tmatrix

from imgutils.preprocess.pillow import PillowResize, _get_pillow_resample
from imgutils.preprocess.pillow import PillowResize, _get_pillow_resample, PillowCenterCrop
from imgutils.preprocess.torchvision import _get_interpolation_mode
from test.testings import get_testfile

@@ -323,3 +323,54 @@ class TestPreprocessPillow:
            antialias=antialias,
        )
        assert repr(size) == repr_text

    def test_center_crop_invalid(self):
        with pytest.raises(ValueError):
            _ = PillowCenterCrop(size='str')

    def test_center_crop_invali_call(self):
        center_crop = PillowCenterCrop(224)
        with pytest.raises(TypeError):
            _ = center_crop(np.random.randn(3, 284, 384))

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision available required.')
    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
        ],
        'size': [
            224,
            384,
            888,
        ]
    }))
    def test_center_crop(self, src_image, size, image_diff):
        from torchvision.transforms import CenterCrop
        image = Image.open(get_testfile(src_image))
        pcentercrop = PillowCenterCrop(size=size)
        tcentercrop = CenterCrop(size=size)
        assert image_diff(pcentercrop(image), tcentercrop(image), throw_exception=False) < 1e-3

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision available required.')
    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
        ],
        'size': [
            (224,),
            (384,),
            (888,),
            (224, 384),
            (384, 224),
            (384, 888),
            (888, 384),
        ]
    }))
    def test_center_crop_pair(self, src_image, size, image_diff):
        from torchvision.transforms import CenterCrop
        image = Image.open(get_testfile(src_image))
        pcentercrop = PillowCenterCrop(size=size)
        tcentercrop = CenterCrop(size=size)
        assert image_diff(pcentercrop(image), tcentercrop(image), throw_exception=False) < 1e-3