Loading imgutils/preprocess/__init__.py +2 −1 Original line number Diff line number Diff line from .pillow import register_pillow_transform, create_pillow_transforms from .torchvision import register_torchvision_transform, create_torchvision_transforms from .torchvision import register_torchvision_transform, create_torchvision_transforms, \ register_torchvision_parse, parse_torchvision_transforms imgutils/preprocess/base.py 0 → 100644 +2 −0 Original line number Diff line number Diff line class NotParseTarget(Exception): pass imgutils/preprocess/torchvision.py +110 −7 Original line number Diff line number Diff line import copy from functools import wraps from typing import Union from .base import NotParseTarget def _check_torchvision(): try: Loading Loading @@ -44,7 +47,7 @@ def _get_interpolation_mode(value): _TRANS_CREATORS = {} def _register(name: str, safe: bool = True): def _register_transform(name: str, safe: bool = True): if safe: _check_torchvision() Loading @@ -56,10 +59,35 @@ def _register(name: str, safe: bool = True): def register_torchvision_transform(name: str): _register(name, safe=True) _register_transform(name, safe=True) _TRANS_PARSERS = {} def _register_parse(name: str, safe: bool = True): if safe: _check_torchvision() def _fn(func): @wraps(func) def _new_func(*args, **kwargs): return { 'type': name, **func(*args, **kwargs), } _TRANS_PARSERS[name] = _new_func return _new_func return _fn def register_torchvision_parse(name: str): _register_parse(name, safe=True) @_register('resize', safe=False) @_register_transform('resize', safe=False) def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True): from torchvision.transforms import Resize return Resize( Loading @@ -70,7 +98,22 @@ def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True ) @_register('center_crop', safe=False) @_register_parse('resize', safe=False) def _parse_resize(obj): from torchvision.transforms import Resize if not isinstance(obj, Resize): raise NotParseTarget obj: Resize return { 'size': obj.size, 'interpolation': obj.interpolation.value, 'max_size': obj.max_size, 'antialias': obj.antialias, } @_register_transform('center_crop', safe=False) def _create_center_crop(size): from torchvision.transforms import CenterCrop return CenterCrop( Loading @@ -78,7 +121,19 @@ def _create_center_crop(size): ) @_register('maybe_to_tensor', safe=False) @_register_parse('center_crop', safe=False) def _parse_center_crop(obj): from torchvision.transforms import CenterCrop if not isinstance(obj, CenterCrop): raise NotParseTarget obj: CenterCrop return { 'size': obj.size, } @_register_transform('maybe_to_tensor', safe=False) def _create_maybe_to_tensor(): from torchvision.transforms import ToTensor class MaybeToTensor(ToTensor): Loading @@ -98,13 +153,29 @@ def _create_maybe_to_tensor(): return MaybeToTensor() @_register('to_tensor', safe=False) @_register_parse('maybe_to_tensor', safe=False) def _parse_maybe_to_tensor(obj): if type(obj).__name__ != 'MaybeToTensor': raise NotParseTarget return {} @_register_transform('to_tensor', safe=False) def _create_to_tensor(): from torchvision.transforms import ToTensor return ToTensor() @_register('normalize', safe=False) @_register_parse('to_tensor', safe=False) def _parse_to_tensor(obj): if type(obj).__name__ != 'ToTensor': raise NotParseTarget return {} @_register_transform('normalize', safe=False) def _create_normalize(mean, std, inplace=False): import torch from torchvision.transforms import Normalize Loading @@ -115,6 +186,19 @@ def _create_normalize(mean, std, inplace=False): ) @_register_parse('normalize', safe=False) def _parse_normalize(obj): from torchvision.transforms import Normalize if not isinstance(obj, Normalize): raise NotParseTarget obj: Normalize return { 'mean': obj.mean.tolist(), 'std': obj.std.tolist(), } def create_torchvision_transforms(tvalue: Union[list, dict]): _check_torchvision() Loading @@ -127,3 +211,22 @@ def create_torchvision_transforms(tvalue: Union[list, dict]): return _TRANS_CREATORS[ttype](**tvalue) else: raise TypeError(f'Unknown type of transforms - {tvalue!r}.') def parse_torchvision_transforms(value): _check_torchvision() from torchvision.transforms import Compose if isinstance(value, Compose): return [ parse_torchvision_transforms(trans) for trans in value.transforms ] else: for key, _parser in _TRANS_PARSERS.items(): try: return _parser(value) except NotParseTarget: pass raise TypeError(f'Unknown parse transform - {value!r}.') test/preprocess/test_torchvision.py +57 −1 Original line number Diff line number Diff line Loading @@ -4,7 +4,8 @@ import pytest from PIL import Image from hbutils.testing import tmatrix from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms, \ parse_torchvision_transforms from test.testings import get_testfile try: Loading Loading @@ -132,3 +133,58 @@ class TestPreprocessPillow: def test_create_transform_non_torchvision(self): with pytest.raises(EnvironmentError): _ = create_torchvision_transforms([]) @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required') def test_parse_torchvision_transforms(self): import torch from torchvision.transforms import Compose, Resize, InterpolationMode, CenterCrop, Normalize, ToTensor assert parse_torchvision_transforms(Compose([ Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True), CenterCrop(size=[384, 384]), create_torchvision_transforms({'type': 'maybe_to_tensor'}), Normalize(mean=torch.tensor([0.5000, 0.5000, 0.5000]), std=torch.tensor([0.5000, 0.5000, 0.5000])), ])) == [ {'antialias': True, 'interpolation': 'bicubic', 'max_size': None, 'size': 384, 'type': 'resize'}, {'size': [384, 384], 'type': 'center_crop'}, {'type': 'maybe_to_tensor'}, {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'type': 'normalize'} ] assert parse_torchvision_transforms(Compose([ Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True), CenterCrop(size=[384, 384]), ToTensor(), Normalize(mean=torch.tensor([0.5000, 0.5000, 0.5000]), std=torch.tensor([0.5000, 0.5000, 0.5000])), ])) == [ {'antialias': True, 'interpolation': 'bicubic', 'max_size': None, 'size': 384, 'type': 'resize'}, {'size': [384, 384], 'type': 'center_crop'}, {'type': 'to_tensor'}, {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'type': 'normalize'} ] assert parse_torchvision_transforms( Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True)) \ == {'antialias': True, 'interpolation': 'bicubic', 'max_size': None, 'size': 384, 'type': 'resize'} assert parse_torchvision_transforms(CenterCrop(size=[384, 384])) == {'size': [384, 384], 'type': 'center_crop'} assert parse_torchvision_transforms(ToTensor()) == {'type': 'to_tensor'} assert parse_torchvision_transforms( Normalize(mean=torch.tensor([0.5000, 0.5000, 0.5000]), std=torch.tensor([0.5000, 0.5000, 0.5000]))) \ == {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'type': 'normalize'} with pytest.raises(TypeError): _ = parse_torchvision_transforms(None) with pytest.raises(TypeError): _ = parse_torchvision_transforms(23344) Loading
imgutils/preprocess/__init__.py +2 −1 Original line number Diff line number Diff line from .pillow import register_pillow_transform, create_pillow_transforms from .torchvision import register_torchvision_transform, create_torchvision_transforms from .torchvision import register_torchvision_transform, create_torchvision_transforms, \ register_torchvision_parse, parse_torchvision_transforms
imgutils/preprocess/base.py 0 → 100644 +2 −0 Original line number Diff line number Diff line class NotParseTarget(Exception): pass
imgutils/preprocess/torchvision.py +110 −7 Original line number Diff line number Diff line import copy from functools import wraps from typing import Union from .base import NotParseTarget def _check_torchvision(): try: Loading Loading @@ -44,7 +47,7 @@ def _get_interpolation_mode(value): _TRANS_CREATORS = {} def _register(name: str, safe: bool = True): def _register_transform(name: str, safe: bool = True): if safe: _check_torchvision() Loading @@ -56,10 +59,35 @@ def _register(name: str, safe: bool = True): def register_torchvision_transform(name: str): _register(name, safe=True) _register_transform(name, safe=True) _TRANS_PARSERS = {} def _register_parse(name: str, safe: bool = True): if safe: _check_torchvision() def _fn(func): @wraps(func) def _new_func(*args, **kwargs): return { 'type': name, **func(*args, **kwargs), } _TRANS_PARSERS[name] = _new_func return _new_func return _fn def register_torchvision_parse(name: str): _register_parse(name, safe=True) @_register('resize', safe=False) @_register_transform('resize', safe=False) def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True): from torchvision.transforms import Resize return Resize( Loading @@ -70,7 +98,22 @@ def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True ) @_register('center_crop', safe=False) @_register_parse('resize', safe=False) def _parse_resize(obj): from torchvision.transforms import Resize if not isinstance(obj, Resize): raise NotParseTarget obj: Resize return { 'size': obj.size, 'interpolation': obj.interpolation.value, 'max_size': obj.max_size, 'antialias': obj.antialias, } @_register_transform('center_crop', safe=False) def _create_center_crop(size): from torchvision.transforms import CenterCrop return CenterCrop( Loading @@ -78,7 +121,19 @@ def _create_center_crop(size): ) @_register('maybe_to_tensor', safe=False) @_register_parse('center_crop', safe=False) def _parse_center_crop(obj): from torchvision.transforms import CenterCrop if not isinstance(obj, CenterCrop): raise NotParseTarget obj: CenterCrop return { 'size': obj.size, } @_register_transform('maybe_to_tensor', safe=False) def _create_maybe_to_tensor(): from torchvision.transforms import ToTensor class MaybeToTensor(ToTensor): Loading @@ -98,13 +153,29 @@ def _create_maybe_to_tensor(): return MaybeToTensor() @_register('to_tensor', safe=False) @_register_parse('maybe_to_tensor', safe=False) def _parse_maybe_to_tensor(obj): if type(obj).__name__ != 'MaybeToTensor': raise NotParseTarget return {} @_register_transform('to_tensor', safe=False) def _create_to_tensor(): from torchvision.transforms import ToTensor return ToTensor() @_register('normalize', safe=False) @_register_parse('to_tensor', safe=False) def _parse_to_tensor(obj): if type(obj).__name__ != 'ToTensor': raise NotParseTarget return {} @_register_transform('normalize', safe=False) def _create_normalize(mean, std, inplace=False): import torch from torchvision.transforms import Normalize Loading @@ -115,6 +186,19 @@ def _create_normalize(mean, std, inplace=False): ) @_register_parse('normalize', safe=False) def _parse_normalize(obj): from torchvision.transforms import Normalize if not isinstance(obj, Normalize): raise NotParseTarget obj: Normalize return { 'mean': obj.mean.tolist(), 'std': obj.std.tolist(), } def create_torchvision_transforms(tvalue: Union[list, dict]): _check_torchvision() Loading @@ -127,3 +211,22 @@ def create_torchvision_transforms(tvalue: Union[list, dict]): return _TRANS_CREATORS[ttype](**tvalue) else: raise TypeError(f'Unknown type of transforms - {tvalue!r}.') def parse_torchvision_transforms(value): _check_torchvision() from torchvision.transforms import Compose if isinstance(value, Compose): return [ parse_torchvision_transforms(trans) for trans in value.transforms ] else: for key, _parser in _TRANS_PARSERS.items(): try: return _parser(value) except NotParseTarget: pass raise TypeError(f'Unknown parse transform - {value!r}.')
test/preprocess/test_torchvision.py +57 −1 Original line number Diff line number Diff line Loading @@ -4,7 +4,8 @@ import pytest from PIL import Image from hbutils.testing import tmatrix from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms, \ parse_torchvision_transforms from test.testings import get_testfile try: Loading Loading @@ -132,3 +133,58 @@ class TestPreprocessPillow: def test_create_transform_non_torchvision(self): with pytest.raises(EnvironmentError): _ = create_torchvision_transforms([]) @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required') def test_parse_torchvision_transforms(self): import torch from torchvision.transforms import Compose, Resize, InterpolationMode, CenterCrop, Normalize, ToTensor assert parse_torchvision_transforms(Compose([ Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True), CenterCrop(size=[384, 384]), create_torchvision_transforms({'type': 'maybe_to_tensor'}), Normalize(mean=torch.tensor([0.5000, 0.5000, 0.5000]), std=torch.tensor([0.5000, 0.5000, 0.5000])), ])) == [ {'antialias': True, 'interpolation': 'bicubic', 'max_size': None, 'size': 384, 'type': 'resize'}, {'size': [384, 384], 'type': 'center_crop'}, {'type': 'maybe_to_tensor'}, {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'type': 'normalize'} ] assert parse_torchvision_transforms(Compose([ Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True), CenterCrop(size=[384, 384]), ToTensor(), Normalize(mean=torch.tensor([0.5000, 0.5000, 0.5000]), std=torch.tensor([0.5000, 0.5000, 0.5000])), ])) == [ {'antialias': True, 'interpolation': 'bicubic', 'max_size': None, 'size': 384, 'type': 'resize'}, {'size': [384, 384], 'type': 'center_crop'}, {'type': 'to_tensor'}, {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'type': 'normalize'} ] assert parse_torchvision_transforms( Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True)) \ == {'antialias': True, 'interpolation': 'bicubic', 'max_size': None, 'size': 384, 'type': 'resize'} assert parse_torchvision_transforms(CenterCrop(size=[384, 384])) == {'size': [384, 384], 'type': 'center_crop'} assert parse_torchvision_transforms(ToTensor()) == {'type': 'to_tensor'} assert parse_torchvision_transforms( Normalize(mean=torch.tensor([0.5000, 0.5000, 0.5000]), std=torch.tensor([0.5000, 0.5000, 0.5000]))) \ == {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'type': 'normalize'} with pytest.raises(TypeError): _ = parse_torchvision_transforms(None) with pytest.raises(TypeError): _ = parse_torchvision_transforms(23344)