Commit e3f72c50 authored by narugo1992
Browse files

dev(narugo): add torchvision transforms

parent d1c5b642
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
from .pillow import register_pillow_transform, create_pillow_transforms
from .torchvision import register_torchvision_transform, create_torchvision_transforms
from .torchvision import register_torchvision_transform, create_torchvision_transforms, \
    register_torchvision_parse, parse_torchvision_transforms
+2 −0
Original line number Diff line number Diff line
class NotParseTarget(Exception):
    """Signal that an object is not the kind of transform a given parser handles."""
+110 −7
Original line number Diff line number Diff line
import copy
from functools import wraps
from typing import Union

from .base import NotParseTarget


def _check_torchvision():
    try:
@@ -44,7 +47,7 @@ def _get_interpolation_mode(value):
# Registry of transform factories keyed by transform type name; entries are
# looked up as ``_TRANS_CREATORS[ttype](**tvalue)`` when transforms are built.
_TRANS_CREATORS = {}


def _register(name: str, safe: bool = True):
def _register_transform(name: str, safe: bool = True):
    if safe:
        _check_torchvision()

@@ -56,10 +59,35 @@ def _register(name: str, safe: bool = True):


def register_torchvision_transform(name: str):
    """
    Create a decorator that registers a transform factory under ``name``.

    Delegates to ``_register_transform`` with ``safe=True``, so the
    torchvision availability check runs before the decorator is created.

    :param name: Type name under which the decorated factory is registered.
    :return: The decorator produced by ``_register_transform``. It must be
        returned, otherwise ``@register_torchvision_transform(...)`` would try
        to call ``None`` on the decorated function.
    """
    return _register_transform(name, safe=True)


# Registry of parser callables keyed by transform type name.
_TRANS_PARSERS = {}


def _register_parse(name: str, safe: bool = True):
    """
    Build a decorator that registers a parser function under ``name``.

    The registered (and returned) callable wraps the original parser so that
    its result mapping is prefixed with a ``'type'`` entry holding ``name``.

    :param name: Type name used both as registry key and as the ``'type'`` value.
    :param safe: When true, ``_check_torchvision()`` is called before the
        decorator is created.
    :return: Decorator taking the parser function and returning the wrapped one.
    """
    if safe:
        _check_torchvision()

    def _decorator(parser):
        @wraps(parser)
        def _tagged_parser(*args, **kwargs):
            fields = parser(*args, **kwargs)
            # 'type' goes first; keys produced by the parser follow (and win on clash).
            return {'type': name, **fields}

        _TRANS_PARSERS[name] = _tagged_parser
        return _tagged_parser

    return _decorator


def register_torchvision_parse(name: str):
    """
    Create a decorator that registers a transform parser under ``name``.

    Delegates to ``_register_parse`` with ``safe=True``, so the torchvision
    availability check runs before the decorator is created.

    :param name: Type name under which the decorated parser is registered.
    :return: The decorator produced by ``_register_parse``. It must be
        returned, otherwise ``@register_torchvision_parse(...)`` would try to
        call ``None`` on the decorated function.
    """
    return _register_parse(name, safe=True)

@_register('resize', safe=False)

@_register_transform('resize', safe=False)
def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True):
    from torchvision.transforms import Resize
    return Resize(
@@ -70,7 +98,22 @@ def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True
    )


@_register('center_crop', safe=False)
@_register_parse('resize', safe=False)
def _parse_resize(obj):
    """
    Extract the configuration of a torchvision ``Resize`` transform.

    :param obj: Object to inspect.
    :return: Mapping with ``size``, ``interpolation`` (the enum's value),
        ``max_size`` and ``antialias`` read from ``obj``.
    :raises NotParseTarget: If ``obj`` is not a ``Resize`` instance.
    """
    from torchvision.transforms import Resize

    if isinstance(obj, Resize):
        return dict(
            size=obj.size,
            interpolation=obj.interpolation.value,
            max_size=obj.max_size,
            antialias=obj.antialias,
        )

    raise NotParseTarget


@_register_transform('center_crop', safe=False)
def _create_center_crop(size):
    from torchvision.transforms import CenterCrop
    return CenterCrop(
@@ -78,7 +121,19 @@ def _create_center_crop(size):
    )


@_register('maybe_to_tensor', safe=False)
@_register_parse('center_crop', safe=False)
def _parse_center_crop(obj):
    """
    Extract the configuration of a torchvision ``CenterCrop`` transform.

    :param obj: Object to inspect.
    :return: Mapping holding the ``size`` attribute of ``obj``.
    :raises NotParseTarget: If ``obj`` is not a ``CenterCrop`` instance.
    """
    from torchvision.transforms import CenterCrop

    if isinstance(obj, CenterCrop):
        return dict(size=obj.size)

    raise NotParseTarget


@_register_transform('maybe_to_tensor', safe=False)
def _create_maybe_to_tensor():
    from torchvision.transforms import ToTensor
    class MaybeToTensor(ToTensor):
@@ -98,13 +153,29 @@ def _create_maybe_to_tensor():
    return MaybeToTensor()


@_register('to_tensor', safe=False)
@_register_parse('maybe_to_tensor', safe=False)
def _parse_maybe_to_tensor(obj):
    """
    Recognise a ``MaybeToTensor`` transform, which carries no parameters.

    :param obj: Object to inspect.
    :return: An empty mapping.
    :raises NotParseTarget: If the class of ``obj`` is not named ``MaybeToTensor``.
    """
    # Matched by class name rather than isinstance: the class is compared
    # purely through ``type(obj).__name__``.
    cls_name = type(obj).__name__
    if cls_name == 'MaybeToTensor':
        return {}

    raise NotParseTarget


@_register_transform('to_tensor', safe=False)
def _create_to_tensor():
    """
    Build a torchvision ``ToTensor`` transform.

    :return: A new ``ToTensor`` instance created without arguments.
    """
    from torchvision import transforms

    to_tensor = transforms.ToTensor()
    return to_tensor


@_register('normalize', safe=False)
@_register_parse('to_tensor', safe=False)
def _parse_to_tensor(obj):
    """
    Recognise a ``ToTensor`` transform, which carries no parameters.

    :param obj: Object to inspect.
    :return: An empty mapping.
    :raises NotParseTarget: If the class of ``obj`` is not named ``ToTensor``.
    """
    # Exact class-name match, so subclasses with a different name are not
    # reported as plain ``to_tensor``.
    cls_name = type(obj).__name__
    if cls_name == 'ToTensor':
        return {}

    raise NotParseTarget


@_register_transform('normalize', safe=False)
def _create_normalize(mean, std, inplace=False):
    import torch
    from torchvision.transforms import Normalize
@@ -115,6 +186,19 @@ def _create_normalize(mean, std, inplace=False):
    )


@_register_parse('normalize', safe=False)
def _parse_normalize(obj):
    """
    Extract the configuration of a torchvision ``Normalize`` transform.

    :param obj: Object to inspect.
    :return: Mapping with ``mean`` and ``std`` converted to plain Python values.
    :raises NotParseTarget: If ``obj`` is not a ``Normalize`` instance.
    """
    from torchvision.transforms import Normalize
    if not isinstance(obj, Normalize):
        raise NotParseTarget

    def _as_plain(stat):
        # torchvision's Normalize keeps mean/std exactly as the caller passed
        # them, so they may be tensors/arrays (which expose ``tolist``) or plain
        # lists/tuples (which do not). Calling ``.tolist()`` unconditionally
        # would raise AttributeError for the latter.
        if hasattr(stat, 'tolist'):
            return stat.tolist()
        if isinstance(stat, (list, tuple)):
            return list(stat)
        return stat

    return {
        'mean': _as_plain(obj.mean),
        'std': _as_plain(obj.std),
    }


def create_torchvision_transforms(tvalue: Union[list, dict]):
    _check_torchvision()

@@ -127,3 +211,22 @@ def create_torchvision_transforms(tvalue: Union[list, dict]):
        return _TRANS_CREATORS[ttype](**tvalue)
    else:
        raise TypeError(f'Unknown type of transforms - {tvalue!r}.')


def parse_torchvision_transforms(value):
    """
    Convert torchvision transform object(s) back into configuration data.

    :param value: A transform object, or a ``Compose`` of transforms.
    :return: For a ``Compose``, a list with one parsed entry per contained
        transform (parsed recursively); otherwise the result of the first
        registered parser that accepts ``value``.
    :raises TypeError: If no registered parser accepts ``value``.
    """
    _check_torchvision()

    from torchvision.transforms import Compose

    if isinstance(value, Compose):
        return [parse_torchvision_transforms(item) for item in value.transforms]

    # Try every registered parser in registration order; a parser rejects an
    # object it does not handle by raising NotParseTarget.
    for candidate in _TRANS_PARSERS.values():
        try:
            parsed = candidate(value)
        except NotParseTarget:
            continue
        else:
            return parsed

    raise TypeError(f'Unknown parse transform - {value!r}.')
+57 −1
Original line number Diff line number Diff line
@@ -4,7 +4,8 @@ import pytest
from PIL import Image
from hbutils.testing import tmatrix

from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms
from imgutils.preprocess.torchvision import _get_interpolation_mode, create_torchvision_transforms, \
    parse_torchvision_transforms
from test.testings import get_testfile

try:
@@ -132,3 +133,58 @@ class TestPreprocessPillow:
    def test_create_transform_non_torchvision(self):
        # Expects create_torchvision_transforms to raise EnvironmentError even
        # for an empty transform list.
        # NOTE(review): presumably this only runs when torchvision is not
        # installed; the guarding decorator is not visible here, so confirm.
        with pytest.raises(EnvironmentError):
            _ = create_torchvision_transforms([])

    @skipUnless(_TORCHVISION_AVAILABLE, 'Torchvision required')
    def test_parse_torchvision_transforms(self):
        # Verifies that parse_torchvision_transforms turns transform objects
        # (composed and standalone) into the expected config data, and rejects
        # values that are not transforms.
        import torch
        from torchvision.transforms import Compose, Resize, InterpolationMode, CenterCrop, Normalize, ToTensor

        def _make_resize():
            return Resize(size=384, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True)

        def _make_center_crop():
            return CenterCrop(size=[384, 384])

        def _make_normalize():
            return Normalize(
                mean=torch.tensor([0.5000, 0.5000, 0.5000]),
                std=torch.tensor([0.5000, 0.5000, 0.5000]),
            )

        expected_resize = {
            'type': 'resize',
            'size': 384,
            'interpolation': 'bicubic',
            'max_size': None,
            'antialias': True,
        }
        expected_center_crop = {'type': 'center_crop', 'size': [384, 384]}
        expected_normalize = {'type': 'normalize', 'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}

        # Composed pipeline using the library's own maybe_to_tensor transform.
        maybe_pipeline = Compose([
            _make_resize(),
            _make_center_crop(),
            create_torchvision_transforms({'type': 'maybe_to_tensor'}),
            _make_normalize(),
        ])
        assert parse_torchvision_transforms(maybe_pipeline) == [
            expected_resize,
            expected_center_crop,
            {'type': 'maybe_to_tensor'},
            expected_normalize,
        ]

        # Composed pipeline using the stock ToTensor transform.
        plain_pipeline = Compose([
            _make_resize(),
            _make_center_crop(),
            ToTensor(),
            _make_normalize(),
        ])
        assert parse_torchvision_transforms(plain_pipeline) == [
            expected_resize,
            expected_center_crop,
            {'type': 'to_tensor'},
            expected_normalize,
        ]

        # Standalone transforms parse to a single mapping each.
        assert parse_torchvision_transforms(_make_resize()) == expected_resize
        assert parse_torchvision_transforms(_make_center_crop()) == expected_center_crop
        assert parse_torchvision_transforms(ToTensor()) == {'type': 'to_tensor'}
        assert parse_torchvision_transforms(_make_normalize()) == expected_normalize

        # Values that are not transforms are rejected.
        for not_a_transform in (None, 23344):
            with pytest.raises(TypeError):
                _ = parse_torchvision_transforms(not_a_transform)