Loading imgutils/preprocess/torchvision.py +55 −39 Original line number Diff line number Diff line import copy from typing import Union def _check_torchvision(): try: import torchvision except (ImportError, ModuleNotFoundError): raise EnvironmentError('No torchvision available.\n' 'Please install it by `pip install torchvision`.') 'Please install it by `pip install dghs-imgutils[torchvision]`.') import torch import torchvision.transforms.functional as F from torchvision.transforms import InterpolationMode, Resize, Compose, CenterCrop, ToTensor, Normalize def _get_interpolation_mode(value: Union[int, str]): from torchvision.transforms import InterpolationMode _INT_TO_INTERMODE = { 0: InterpolationMode.NEAREST, 2: InterpolationMode.BILINEAR, Loading @@ -25,8 +26,6 @@ _STR_TO_INTERMODE = { for key, value in InterpolationMode.__members__.items() } def _get_interpolation_mode(value: Union[int, str, InterpolationMode]): if isinstance(value, InterpolationMode): return value elif isinstance(value, int): Loading @@ -46,6 +45,8 @@ _TRANS_CREATORS = {} def register_torchvision_transform(name: str): _check_torchvision() def _fn(func): _TRANS_CREATORS[name] = func return func Loading @@ -54,7 +55,8 @@ def register_torchvision_transform(name: str): @register_torchvision_transform('resize') def _create_resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True): def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True): from torchvision.transforms import Resize return Resize( size=size, interpolation=_get_interpolation_mode(interpolation), Loading @@ -65,16 +67,22 @@ def _create_resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None @register_torchvision_transform('center_crop') def _create_center_crop(size): from torchvision.transforms import CenterCrop return CenterCrop( size=size, ) @register_torchvision_transform('maybe_to_tensor') def _create_maybe_to_tensor(): from torchvision.transforms import ToTensor class MaybeToTensor(ToTensor): def __init__(self) -> None: super().__init__() def __call__(self, pic): import torchvision.transforms.functional as F import torch if isinstance(pic, torch.Tensor): return pic return F.to_tensor(pic) Loading @@ -82,14 +90,19 @@ class MaybeToTensor(ToTensor): def __repr__(self) -> str: return f"{self.__class__.__name__}()" @register_torchvision_transform('maybe_to_tensor') def _create_maybe_to_tensor(): return MaybeToTensor() @register_torchvision_transform('to_tensor') def _create_to_tensor(): from torchvision.transforms import ToTensor return ToTensor() @register_torchvision_transform('normalize') def _create_normalize(mean, std, inplace=False): import torch from torchvision.transforms import Normalize return Normalize( mean=torch.tensor(mean), std=torch.tensor(std), Loading @@ -98,6 +111,9 @@ def _create_normalize(mean, std, inplace=False): def create_torchvision_transforms(tvalue: Union[list, dict]): _check_torchvision() from torchvision.transforms import Compose if isinstance(tvalue, list): return Compose([create_torchvision_transforms(titem) for titem in tvalue]) elif isinstance(tvalue, dict): Loading Loading
imgutils/preprocess/torchvision.py +55 −39 Original line number Diff line number Diff line import copy from typing import Union def _check_torchvision(): try: import torchvision except (ImportError, ModuleNotFoundError): raise EnvironmentError('No torchvision available.\n' 'Please install it by `pip install torchvision`.') 'Please install it by `pip install dghs-imgutils[torchvision]`.') import torch import torchvision.transforms.functional as F from torchvision.transforms import InterpolationMode, Resize, Compose, CenterCrop, ToTensor, Normalize def _get_interpolation_mode(value: Union[int, str]): from torchvision.transforms import InterpolationMode _INT_TO_INTERMODE = { 0: InterpolationMode.NEAREST, 2: InterpolationMode.BILINEAR, Loading @@ -25,8 +26,6 @@ _STR_TO_INTERMODE = { for key, value in InterpolationMode.__members__.items() } def _get_interpolation_mode(value: Union[int, str, InterpolationMode]): if isinstance(value, InterpolationMode): return value elif isinstance(value, int): Loading @@ -46,6 +45,8 @@ _TRANS_CREATORS = {} def register_torchvision_transform(name: str): _check_torchvision() def _fn(func): _TRANS_CREATORS[name] = func return func Loading @@ -54,7 +55,8 @@ def register_torchvision_transform(name: str): @register_torchvision_transform('resize') def _create_resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=True): def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True): from torchvision.transforms import Resize return Resize( size=size, interpolation=_get_interpolation_mode(interpolation), Loading @@ -65,16 +67,22 @@ def _create_resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None @register_torchvision_transform('center_crop') def _create_center_crop(size): from torchvision.transforms import CenterCrop return CenterCrop( size=size, ) @register_torchvision_transform('maybe_to_tensor') def _create_maybe_to_tensor(): from torchvision.transforms import ToTensor class MaybeToTensor(ToTensor): def __init__(self) -> None: super().__init__() def __call__(self, pic): import torchvision.transforms.functional as F import torch if isinstance(pic, torch.Tensor): return pic return F.to_tensor(pic) Loading @@ -82,14 +90,19 @@ class MaybeToTensor(ToTensor): def __repr__(self) -> str: return f"{self.__class__.__name__}()" @register_torchvision_transform('maybe_to_tensor') def _create_maybe_to_tensor(): return MaybeToTensor() @register_torchvision_transform('to_tensor') def _create_to_tensor(): from torchvision.transforms import ToTensor return ToTensor() @register_torchvision_transform('normalize') def _create_normalize(mean, std, inplace=False): import torch from torchvision.transforms import Normalize return Normalize( mean=torch.tensor(mean), std=torch.tensor(std), Loading @@ -98,6 +111,9 @@ def _create_normalize(mean, std, inplace=False): def create_torchvision_transforms(tvalue: Union[list, dict]): _check_torchvision() from torchvision.transforms import Compose if isinstance(tvalue, list): return Compose([create_torchvision_transforms(titem) for titem in tvalue]) elif isinstance(tvalue, dict): Loading