Commit 7697fc29 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): prepare for the clip preprocessor

parent 1fa81a1d
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -47,6 +47,7 @@ jobs:
          python -m pip install -r requirements-model.txt
          python -m pip install -r requirements-doc.txt
          python -m pip install -r requirements-torchvision.txt
          python -m pip install -r requirements-transformers.txt
      - name: Prepare dataset
        uses: nick-fields/retry@v2
        if: ${{ github.event_name == 'push' }}
+1 −0
Original line number Diff line number Diff line
@@ -101,6 +101,7 @@ jobs:
        run: |
          pip install -r requirements-model.txt
          pip install -r requirements-torchvision.txt
          pip install -r requirements-transformers.txt
      - name: Test the basic environment
        shell: bash
        run: |
+27 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ import numpy as np
from PIL import Image

from .base import NotParseTarget
from ..data import load_image

# noinspection PyUnresolvedReferences
_INT_TO_PILLOW = {
@@ -651,6 +652,32 @@ def _parse_normalize(obj: PillowNormalize):
    }


class PillowConvertRGB:
    def __init__(self, force_background: Optional[str] = 'white'):
        self.force_background = force_background

    def forward(self, pic):
        if not isinstance(pic, Image.Image):
            raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))
        return load_image(pic, mode='RGB', force_background=self.force_background)

    def __repr__(self):
        return f'{self.__class__.__name__}(force_background={self.force_background!r})'


class PillowRescale:
    def __init__(self, rescale_factor: float = 1 / 255):
        self.rescale_factor = rescale_factor

    def __call__(self, array):
        if not isinstance(array, np.ndarray):
            raise TypeError('Input should be a numpy.ndarray')
        return array * self.rescale_factor

    def __repr__(self):
        return f'{self.__class__.__name__}(rescale_factor={self.rescale_factor!r})'


class PillowCompose:
    """
    Composes several transforms together into a single transform.
+22 −10
Original line number Diff line number Diff line
@@ -13,6 +13,13 @@ from typing import Union

from .base import NotParseTarget

try:
    import torchvision
except (ImportError, ModuleNotFoundError):
    _HAS_TORCHVISION = False
else:
    _HAS_TORCHVISION = True


def _check_torchvision():
    """
@@ -20,9 +27,7 @@ def _check_torchvision():

    :raises EnvironmentError: If torchvision is not installed
    """
    try:
        import torchvision
    except (ImportError, ModuleNotFoundError):
    if not _HAS_TORCHVISION:
        raise EnvironmentError('No torchvision available.\n'
                               'Please install it by `pip install dghs-imgutils[torchvision]`.')

@@ -210,14 +215,10 @@ def _parse_center_crop(obj):
    }


@_register_transform('maybe_to_tensor', safe=False)
def _create_maybe_to_tensor():
    """
    Create a MaybeToTensor transform that converts input to tensor if not already a tensor.

    :return: MaybeToTensor transform
    """
if _HAS_TORCHVISION:
    from torchvision.transforms import ToTensor


    class MaybeToTensor(ToTensor):
        def __init__(self) -> None:
            super().__init__()
@@ -232,6 +233,17 @@ def _create_maybe_to_tensor():
        def __repr__(self) -> str:
            return f"{self.__class__.__name__}()"

else:
    MaybeToTensor = None


@_register_transform('maybe_to_tensor', safe=False)
def _create_maybe_to_tensor():
    """
    Create a MaybeToTensor transform that converts input to tensor if not already a tensor.

    :return: MaybeToTensor transform
    """
    return MaybeToTensor()


+0 −0

Empty file added.

Loading