Loading imgutils/preprocess/__init__.py +1 −0 Original line number Diff line number Diff line from .base import NotParseTarget from .pillow import register_pillow_transform, create_pillow_transforms, \ register_pillow_parse, parse_pillow_transforms from .torchvision import register_torchvision_transform, create_torchvision_transforms, \ Loading imgutils/preprocess/torchvision.py +2 −2 Original line number Diff line number Diff line Loading @@ -59,7 +59,7 @@ def _register_transform(name: str, safe: bool = True): def register_torchvision_transform(name: str): _register_transform(name, safe=True) return _register_transform(name, safe=True) _TRANS_PARSERS = {} Loading @@ -84,7 +84,7 @@ def _register_parse(name: str, safe: bool = True): def register_torchvision_parse(name: str): _register_parse(name, safe=True) return _register_parse(name, safe=True) @_register_transform('resize', safe=False) Loading test/preprocess/test_torchvision.py +48 −1 Original line number Diff line number Diff line from typing import Union, Tuple from unittest import skipUnless import pytest from PIL import Image from hbutils.testing import tmatrix from imgutils.preprocess import NotParseTarget from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms, \ parse_torchvision_transforms parse_torchvision_transforms, register_torchvision_transform, register_torchvision_parse from test.testings import get_testfile try: Loading Loading @@ -188,3 +190,48 @@ class TestPreprocessPillow: _ = parse_torchvision_transforms(None) with pytest.raises(TypeError): _ = parse_torchvision_transforms(23344) @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.') def test_register_and_use(self): from torchvision.transforms import ColorJitter @register_torchvision_transform('color_jitter') def _create_color_jitter( brightness: Union[float, Tuple[float, float]] = 0, contrast: Union[float, Tuple[float, float]] = 0, saturation: Union[float, Tuple[float, float]] = 0, hue: Union[float, Tuple[float, float]] = 0, ): return ColorJitter(brightness, contrast, saturation, hue) @register_torchvision_parse('color_jitter') def _parse_color_jitter(obj: ColorJitter): if not isinstance(obj, ColorJitter): raise NotParseTarget return { 'brightness': obj.brightness, 'contrast': obj.contrast, 'saturation': obj.saturation, 'hue': obj.hue, } c = create_torchvision_transforms({ 'type': 'color_jitter', 'brightness': 0.5, 'contrast': 0.2, 'saturation': (0.0, 0.8), 'hue': (0.1, 0.45), }) assert isinstance(c, ColorJitter) assert c.brightness == pytest.approx((0.5, 1.5)) assert c.contrast == pytest.approx((0.8, 1.2)) assert c.saturation == pytest.approx((0.0, 0.8)) assert c.hue == pytest.approx((0.1, 0.45)) assert parse_torchvision_transforms(c) == pytest.approx({ 'brightness': (0.5, 1.5), 'contrast': (0.8, 1.2), 'hue': (0.1, 0.45), 'saturation': (0.0, 0.8), 'type': 'color_jitter' }) Loading
imgutils/preprocess/__init__.py +1 −0 Original line number Diff line number Diff line from .base import NotParseTarget from .pillow import register_pillow_transform, create_pillow_transforms, \ register_pillow_parse, parse_pillow_transforms from .torchvision import register_torchvision_transform, create_torchvision_transforms, \ Loading
imgutils/preprocess/torchvision.py +2 −2 Original line number Diff line number Diff line Loading @@ -59,7 +59,7 @@ def _register_transform(name: str, safe: bool = True): def register_torchvision_transform(name: str): _register_transform(name, safe=True) return _register_transform(name, safe=True) _TRANS_PARSERS = {} Loading @@ -84,7 +84,7 @@ def _register_parse(name: str, safe: bool = True): def register_torchvision_parse(name: str): _register_parse(name, safe=True) return _register_parse(name, safe=True) @_register_transform('resize', safe=False) Loading
test/preprocess/test_torchvision.py +48 −1 Original line number Diff line number Diff line from typing import Union, Tuple from unittest import skipUnless import pytest from PIL import Image from hbutils.testing import tmatrix from imgutils.preprocess import NotParseTarget from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms, \ parse_torchvision_transforms parse_torchvision_transforms, register_torchvision_transform, register_torchvision_parse from test.testings import get_testfile try: Loading Loading @@ -188,3 +190,48 @@ class TestPreprocessPillow: _ = parse_torchvision_transforms(None) with pytest.raises(TypeError): _ = parse_torchvision_transforms(23344) @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required.') def test_register_and_use(self): from torchvision.transforms import ColorJitter @register_torchvision_transform('color_jitter') def _create_color_jitter( brightness: Union[float, Tuple[float, float]] = 0, contrast: Union[float, Tuple[float, float]] = 0, saturation: Union[float, Tuple[float, float]] = 0, hue: Union[float, Tuple[float, float]] = 0, ): return ColorJitter(brightness, contrast, saturation, hue) @register_torchvision_parse('color_jitter') def _parse_color_jitter(obj: ColorJitter): if not isinstance(obj, ColorJitter): raise NotParseTarget return { 'brightness': obj.brightness, 'contrast': obj.contrast, 'saturation': obj.saturation, 'hue': obj.hue, } c = create_torchvision_transforms({ 'type': 'color_jitter', 'brightness': 0.5, 'contrast': 0.2, 'saturation': (0.0, 0.8), 'hue': (0.1, 0.45), }) assert isinstance(c, ColorJitter) assert c.brightness == pytest.approx((0.5, 1.5)) assert c.contrast == pytest.approx((0.8, 1.2)) assert c.saturation == pytest.approx((0.0, 0.8)) assert c.hue == pytest.approx((0.1, 0.45)) assert parse_torchvision_transforms(c) == pytest.approx({ 'brightness': (0.5, 1.5), 'contrast': (0.8, 1.2), 'hue': (0.1, 0.45), 'saturation': (0.0, 0.8), 'type': 'color_jitter' })