Commit ecd829a1 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest extra parse

parent fe3cbb93
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from .base import NotParseTarget
from .pillow import register_pillow_transform, create_pillow_transforms, \
    register_pillow_parse, parse_pillow_transforms
from .torchvision import register_torchvision_transform, create_torchvision_transforms, \
+2 −2
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ def _register_transform(name: str, safe: bool = True):


def register_torchvision_transform(name: str):
    _register_transform(name, safe=True)
    return _register_transform(name, safe=True)


_TRANS_PARSERS = {}
@@ -84,7 +84,7 @@ def _register_parse(name: str, safe: bool = True):


def register_torchvision_parse(name: str):
    _register_parse(name, safe=True)
    return _register_parse(name, safe=True)


@_register_transform('resize', safe=False)
+48 −1
Original line number Diff line number Diff line
from typing import Union, Tuple
from unittest import skipUnless

import pytest
from PIL import Image
from hbutils.testing import tmatrix

from imgutils.preprocess import NotParseTarget
from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms, \
    parse_torchvision_transforms
    parse_torchvision_transforms, register_torchvision_transform, register_torchvision_parse
from test.testings import get_testfile

try:
@@ -188,3 +190,48 @@ class TestPreprocessPillow:
            _ = parse_torchvision_transforms(None)
        with pytest.raises(TypeError):
            _ = parse_torchvision_transforms(23344)

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.')
    def test_register_and_use(self):
        from torchvision.transforms import ColorJitter
        @register_torchvision_transform('color_jitter')
        def _create_color_jitter(
                brightness: Union[float, Tuple[float, float]] = 0,
                contrast: Union[float, Tuple[float, float]] = 0,
                saturation: Union[float, Tuple[float, float]] = 0,
                hue: Union[float, Tuple[float, float]] = 0,
        ):
            return ColorJitter(brightness, contrast, saturation, hue)

        @register_torchvision_parse('color_jitter')
        def _parse_color_jitter(obj: ColorJitter):
            if not isinstance(obj, ColorJitter):
                raise NotParseTarget

            return {
                'brightness': obj.brightness,
                'contrast': obj.contrast,
                'saturation': obj.saturation,
                'hue': obj.hue,
            }

        c = create_torchvision_transforms({
            'type': 'color_jitter',
            'brightness': 0.5,
            'contrast': 0.2,
            'saturation': (0.0, 0.8),
            'hue': (0.1, 0.45),
        })
        assert isinstance(c, ColorJitter)
        assert c.brightness == pytest.approx((0.5, 1.5))
        assert c.contrast == pytest.approx((0.8, 1.2))
        assert c.saturation == pytest.approx((0.0, 0.8))
        assert c.hue == pytest.approx((0.1, 0.45))

        assert parse_torchvision_transforms(c) == pytest.approx({
            'brightness': (0.5, 1.5),
            'contrast': (0.8, 1.2),
            'hue': (0.1, 0.45),
            'saturation': (0.0, 0.8),
            'type': 'color_jitter'
        })