Commit 7f2ea70b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add torchvision update to make sure usable

parent 387b3554
Loading
Loading
Loading
Loading
+55 −39
Original line number Diff line number Diff line
import copy
from typing import Union


def _check_torchvision():
    try:
        import torchvision
    except (ImportError, ModuleNotFoundError):
        raise EnvironmentError('No torchvision available.\n'
                           'Please install it by `pip install torchvision`.')
                               'Please install it by `pip install dghs-imgutils[torchvision]`.')

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode, Resize, Compose, CenterCrop, ToTensor, Normalize

def _get_interpolation_mode(value: Union[int, str]):
    from torchvision.transforms import InterpolationMode
    _INT_TO_INTERMODE = {
        0: InterpolationMode.NEAREST,
        2: InterpolationMode.BILINEAR,
@@ -25,8 +26,6 @@ _STR_TO_INTERMODE = {
        for key, value in InterpolationMode.__members__.items()
    }


def _get_interpolation_mode(value: Union[int, str, InterpolationMode]):
    if isinstance(value, InterpolationMode):
        return value
    elif isinstance(value, int):
@@ -46,6 +45,8 @@ _TRANS_CREATORS = {}


def register_torchvision_transform(name: str):
    _check_torchvision()

    def _fn(func):
        _TRANS_CREATORS[name] = func
        return func
@@ -54,7 +55,8 @@ def register_torchvision_transform(name: str):


@register_torchvision_transform('resize')
def _create_resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True):
def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True):
    from torchvision.transforms import Resize
    return Resize(
        size=size,
        interpolation=_get_interpolation_mode(interpolation),
@@ -65,16 +67,22 @@ def _create_resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None

@register_torchvision_transform('center_crop')
def _create_center_crop(size):
    from torchvision.transforms import CenterCrop
    return CenterCrop(
        size=size,
    )


@register_torchvision_transform('maybe_to_tensor')
def _create_maybe_to_tensor():
    from torchvision.transforms import ToTensor
    class MaybeToTensor(ToTensor):
        def __init__(self) -> None:
            super().__init__()

        def __call__(self, pic):
            import torchvision.transforms.functional as F
            import torch
            if isinstance(pic, torch.Tensor):
                return pic
            return F.to_tensor(pic)
@@ -82,14 +90,19 @@ class MaybeToTensor(ToTensor):
        def __repr__(self) -> str:
            return f"{self.__class__.__name__}()"


@register_torchvision_transform('maybe_to_tensor')
def _create_maybe_to_tensor():
    return MaybeToTensor()


@register_torchvision_transform('to_tensor')
def _create_to_tensor():
    from torchvision.transforms import ToTensor
    return ToTensor()


@register_torchvision_transform('normalize')
def _create_normalize(mean, std, inplace=False):
    import torch
    from torchvision.transforms import Normalize
    return Normalize(
        mean=torch.tensor(mean),
        std=torch.tensor(std),
@@ -98,6 +111,9 @@ def _create_normalize(mean, std, inplace=False):


def create_torchvision_transforms(tvalue: Union[list, dict]):
    _check_torchvision()

    from torchvision.transforms import Compose
    if isinstance(tvalue, list):
        return Compose([create_torchvision_transforms(titem) for titem in tvalue])
    elif isinstance(tvalue, dict):