Commit 3375ba3b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save x

parent e7b5c469
Loading
Loading
Loading
Loading
+12 −7
Original line number Diff line number Diff line
@@ -44,7 +44,8 @@ def _get_interpolation_mode(value: Union[int, str]):
_TRANS_CREATORS = {}


def register_torchvision_transform(name: str):
def _register(name: str, safe: bool = True):
    if safe:
        _check_torchvision()

    def _fn(func):
@@ -54,7 +55,11 @@ def register_torchvision_transform(name: str):
    return _fn


@register_torchvision_transform('resize')
def register_torchvision_transform(name: str):
    _register(name, safe=True)


@_register('resize', safe=False)
def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True):
    from torchvision.transforms import Resize
    return Resize(
@@ -65,7 +70,7 @@ def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True
    )


@register_torchvision_transform('center_crop')
@_register('center_crop', safe=False)
def _create_center_crop(size):
    from torchvision.transforms import CenterCrop
    return CenterCrop(
@@ -73,7 +78,7 @@ def _create_center_crop(size):
    )


@register_torchvision_transform('maybe_to_tensor')
@_register('maybe_to_tensor', safe=False)
def _create_maybe_to_tensor():
    from torchvision.transforms import ToTensor
    class MaybeToTensor(ToTensor):
@@ -93,13 +98,13 @@ def _create_maybe_to_tensor():
    return MaybeToTensor()


@register_torchvision_transform('to_tensor')
@_register('to_tensor', safe=False)
def _create_to_tensor():
    from torchvision.transforms import ToTensor
    return ToTensor()


@register_torchvision_transform('normalize')
@_register('normalize', safe=False)
def _create_normalize(mean, std, inplace=False):
    import torch
    from torchvision.transforms import Normalize
+1 −0
Original line number Diff line number Diff line
@@ -387,6 +387,7 @@ class TestPreprocessPillow:
    def test_center_crop_repr(self, size, repr_text):
        assert repr(PillowCenterCrop(size=size)) == repr_text

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required')
    @pytest.mark.parametrize(*tmatrix({
        'src_image': [
            'png_640.png',
+13 −18
Original line number Diff line number Diff line
@@ -18,7 +18,6 @@ else:
    _TORCHVISION_AVAILABLE = True


@pytest.fixture()
def torchvision_maybetotensor():
    from torchvision.transforms import ToTensor
    class MaybeToTensor(ToTensor):
@@ -38,19 +37,17 @@ def torchvision_maybetotensor():
    return MaybeToTensor()


@pytest.fixture()
def torchvision_mobilenet(torchvision_maybetotensor):
def torchvision_mobilenet():
    import torch
    from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode
    return Compose([
        Resize(size=404, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True),
        CenterCrop(size=[384, 384]),
        torchvision_maybetotensor,
        torchvision_maybetotensor(),
        Normalize(mean=torch.tensor([0.4850, 0.4560, 0.4060]), std=torch.tensor([0.2290, 0.2240, 0.2250])),
    ])


@pytest.fixture()
def pillow_mobilenet():
    return PillowCompose([
        PillowResize(size=404, interpolation=Image.BICUBIC, max_size=None, antialias=True),
@@ -60,19 +57,17 @@ def pillow_mobilenet():
    ])


@pytest.fixture()
def torchvision_beit(torchvision_maybetotensor):
def torchvision_beit():
    import torch
    from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, InterpolationMode
    return Compose([
        Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True),
        CenterCrop(size=[384, 384]),
        torchvision_maybetotensor,
        torchvision_maybetotensor(),
        Normalize(mean=torch.tensor([0.5000, 0.5000, 0.5000]), std=torch.tensor([0.5000, 0.5000, 0.5000])),
    ])


@pytest.fixture()
def pillow_beit():
    return PillowCompose([
        PillowResize(size=384, interpolation=Image.BICUBIC, max_size=None, antialias=True),
@@ -91,10 +86,10 @@ class TestPreprocessPillowCompose:
            'png_640_m90.png',
        ],
    }))
    def test_compose_mobilenet(self, src_image, pillow_mobilenet, torchvision_mobilenet):
    def test_compose_mobilenet(self, src_image):
        image = Image.open(get_testfile(src_image))
        presult = pillow_mobilenet(image)
        tresult = torchvision_mobilenet(image)
        presult = pillow_mobilenet()(image)
        tresult = torchvision_mobilenet()(image)
        np.testing.assert_array_almost_equal(presult, tresult.numpy())

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.')
@@ -104,14 +99,14 @@ class TestPreprocessPillowCompose:
            'png_640_m90.png',
        ],
    }))
    def test_compose_beit(self, src_image, pillow_beit, torchvision_beit):
    def test_compose_beit(self, src_image):
        image = Image.open(get_testfile(src_image))
        presult = pillow_beit(image)
        tresult = torchvision_beit(image)
        presult = pillow_beit()(image)
        tresult = torchvision_beit()(image)
        np.testing.assert_array_almost_equal(presult, tresult.numpy())

    def test_compose_repr(self, pillow_mobilenet, pillow_beit):
        assert textwrap.dedent(repr(pillow_mobilenet)).strip() == \
    def test_compose_repr(self):
        assert textwrap.dedent(repr(pillow_mobilenet())).strip() == \
               textwrap.dedent("""
PillowCompose(
    PillowResize(size=404, interpolation=bicubic, max_size=None, antialias=True)
@@ -121,7 +116,7 @@ PillowCompose(
)
            """).strip()

        assert textwrap.dedent(repr(pillow_beit)).strip() == \
        assert textwrap.dedent(repr(pillow_beit())).strip() == \
               textwrap.dedent("""
PillowCompose(
    PillowResize(size=384, interpolation=bicubic, max_size=None, antialias=True)