Commit d11c6bd2 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): alignment test passed

parent 3375ba3b
Loading
Loading
Loading
Loading
+91 −1
Original line number Diff line number Diff line
@@ -7,7 +7,8 @@ from PIL import Image
from hbutils.testing import tmatrix

from imgutils.preprocess.pillow import PillowNormalize, PillowCompose, PillowResize, PillowMaybeToTensor, \
    PillowCenterCrop
    PillowCenterCrop, create_pillow_transforms
from imgutils.preprocess.torchvision import create_torchvision_transforms
from test.testings import get_testfile

try:
@@ -18,6 +19,71 @@ else:
    _TORCHVISION_AVAILABLE = True


@pytest.fixture()
def mobilenetv4_conv_large_e600_r384_in1k():
    return [{'antialias': True,
             'interpolation': 'bicubic',
             'max_size': None,
             'size': 404,
             'type': 'resize'},
            {'size': [384, 384], 'type': 'center_crop'},
            {'type': 'maybe_to_tensor'},
            {'mean': [0.48500001430511475, 0.4560000002384186, 0.4059999883174896],
             'std': [0.2290000021457672, 0.2240000069141388, 0.22499999403953552],
             'type': 'normalize'}]


@pytest.fixture()
def caformer_s36_sail_in1k_384():
    return [{'antialias': True,
             'interpolation': 'bicubic',
             'max_size': None,
             'size': 384,
             'type': 'resize'},
            {'size': [384, 384], 'type': 'center_crop'},
            {'type': 'maybe_to_tensor'},
            {'mean': [0.48500001430511475, 0.4560000002384186, 0.4059999883174896],
             'std': [0.2290000021457672, 0.2240000069141388, 0.22499999403953552],
             'type': 'normalize'}]


@pytest.fixture()
def beit_base_patch16_384_in22k_ft_in22k_in1k():
    return [{'antialias': True,
             'interpolation': 'bicubic',
             'max_size': None,
             'size': 384,
             'type': 'resize'},
            {'size': [384, 384], 'type': 'center_crop'},
            {'type': 'maybe_to_tensor'},
            {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'type': 'normalize'}]


@pytest.fixture()
def resnet101d_ra2_in1k():
    return [{'antialias': True,
             'interpolation': 'bicubic',
             'max_size': None,
             'size': 269,
             'type': 'resize'},
            {'size': [256, 256], 'type': 'center_crop'},
            {'type': 'maybe_to_tensor'},
            {'mean': [0.48500001430511475, 0.4560000002384186, 0.4059999883174896],
             'std': [0.2290000021457672, 0.2240000069141388, 0.22499999403953552],
             'type': 'normalize'}]


@pytest.fixture()
def meta_collect(mobilenetv4_conv_large_e600_r384_in1k, resnet101d_ra2_in1k, caformer_s36_sail_in1k_384,
                 beit_base_patch16_384_in22k_ft_in22k_in1k):
    return {
        'mobilenetv4_conv_large.e600_r384_in1k': mobilenetv4_conv_large_e600_r384_in1k,
        'resnet101d.ra2_in1k': resnet101d_ra2_in1k,
        'caformer_s36.sail_in1k_384': caformer_s36_sail_in1k_384,
        'beit_base_patch16_384.in22k_ft_in22k_in1k': beit_base_patch16_384_in22k_ft_in22k_in1k,
    }


def torchvision_maybetotensor():
    from torchvision.transforms import ToTensor
    class MaybeToTensor(ToTensor):
@@ -125,3 +191,27 @@ PillowCompose(
    PillowNormalize(mean=[0.5 0.5 0.5], std=[0.5 0.5 0.5])
)
            """).strip()

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.')
    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
            'png_640_m90.png',
        ],
        'meta_name': [
            'mobilenetv4_conv_large.e600_r384_in1k',
            'caformer_s36.sail_in1k_384',
            'beit_base_patch16_384.in22k_ft_in22k_in1k',
            'resnet101d.ra2_in1k',
        ]
    }))
    def test_compose_alignment(self, src_image, meta_name, meta_collect):
        image = Image.open(get_testfile(src_image))
        meta = meta_collect[meta_name]

        ptrans = create_pillow_transforms(meta)
        ttrans = create_torchvision_transforms(meta)
        np.testing.assert_array_almost_equal(
            ptrans(image),
            ttrans(image).numpy(),
        )