Commit 1cd70965 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save current new code

parent 4f475a6d
Loading
Loading
Loading
Loading
+89 −9
Original line number Diff line number Diff line
@@ -112,6 +112,25 @@ def _get_int_from_interpolation_mode(value):
    return _INTERMODE_TO_INT[value]


def _get_interpolation_str_from_mode(value) -> str:
    """
    Convert a torchvision ``InterpolationMode`` to the ``mode`` string used by ``F.interpolate``.

    Modes without a direct ``F.interpolate`` counterpart are mapped to the closest
    available one (``BOX`` -> ``'area'``, ``HAMMING`` -> ``'bilinear'``,
    ``LANCZOS`` -> ``'bicubic'``).

    :param value: Interpolation mode to convert.
    :type value: torchvision.transforms.InterpolationMode
    :return: Mode name accepted by ``torch.nn.functional.interpolate``.
    :rtype: str
    :raises TypeError: If ``value`` is not an ``InterpolationMode``.
    :raises KeyError: If ``value`` is an ``InterpolationMode`` member missing from the mapping below.
    """
    # Imported lazily, so torchvision is only required when this function is actually called.
    from torchvision.transforms import InterpolationMode
    if not isinstance(value, InterpolationMode):
        raise TypeError(
            f'Unknown type of interpolation mode, cannot be transformed to str - {value!r}')  # pragma: no cover

    mode_to_str = {
        InterpolationMode.NEAREST: 'nearest',
        InterpolationMode.BILINEAR: 'bilinear',
        InterpolationMode.BICUBIC: 'bicubic',
        # For modes not directly supported by F.interpolate, we map to the closest equivalent
        InterpolationMode.BOX: 'area',  # BOX is similar to area interpolation
        InterpolationMode.HAMMING: 'bilinear',  # No direct equivalent, use bilinear
        InterpolationMode.LANCZOS: 'bicubic',  # No direct equivalent, use bicubic
    }
    return mode_to_str[value]


_TRANS_CREATORS = {}


@@ -370,6 +389,7 @@ def _parse_normalize(obj):

if _HAS_TORCHVISION:
    from torchvision.transforms import InterpolationMode
    import torch.nn.functional as F


    class PadToSize(torch.nn.Module):
@@ -396,6 +416,69 @@ if _HAS_TORCHVISION:
            self.interpolation: InterpolationMode = interpolation
            _parse_color_to_rgba(self.background_color)

        def _pad_pil_image(self, pic):
            """
            Pad a PIL image by delegating to ``pad_image_to_size``.

            The transform's ``size`` and ``background_color`` are forwarded unchanged; the
            stored ``InterpolationMode`` is first converted to its integer form.

            :param pic: Image to pad.
            :return: Whatever ``pad_image_to_size`` returns for the given image.
            """
            resample = _get_int_from_interpolation_mode(self.interpolation)
            padded = pad_image_to_size(
                pic=pic,
                size=self.size,
                background_color=self.background_color,
                interpolation=resample,
            )
            return padded

        def _pad_tensor(self, tensor):
            """
            Resize a tensor image to fit inside ``self.size`` and pad it with the background color.

            The input is scaled with its aspect ratio preserved so that it fits inside the
            target size, then pasted at the center of a canvas filled with
            ``self.background_color``.

            :param tensor: Image tensor laid out as ``(C, H, W)`` or ``(B, C, H, W)``.
            :type tensor: torch.Tensor
            :return: Padded tensor with spatial size ``(target_h, target_w)``, keeping the
                input's dtype and batched/unbatched layout.
            :rtype: torch.Tensor
            :raises ValueError: If the tensor is not 3- or 4-dimensional, or has a channel
                count outside 1-4.
            """
            from ..data.pad import _parse_color_to_mode

            if tensor.ndim < 3 or tensor.ndim > 4:
                raise ValueError(f"Tensor should have 3 or 4 dimensions, got {tensor.ndim}")

            # Handle batched and unbatched tensors
            is_batched = tensor.ndim == 4
            if not is_batched:
                tensor = tensor.unsqueeze(0)

            # Get tensor properties
            b, c, h, w = tensor.shape
            channel_modes = {1: 'L', 2: 'LA', 3: 'RGB', 4: 'RGBA'}
            if c not in channel_modes:
                # Fail with a readable message instead of a bare KeyError from the lookup below.
                raise ValueError(f"Tensor should have 1 to 4 channels, got {c}")
            target_w, target_h = self.size

            # Calculate new dimensions preserving aspect ratio. Each side is clamped to at
            # least one pixel: extreme aspect ratios would otherwise round to 0, which
            # F.interpolate rejects.
            ratio = min(target_w / w, target_h / h)
            new_h = max(1, round(h * ratio))
            new_w = max(1, round(w * ratio))

            # Resize tensor
            mode = _get_interpolation_str_from_mode(self.interpolation)
            smooth = mode in {'bicubic', 'bilinear'}
            out_dtype = tensor.dtype
            # Averaging modes are not reliably implemented for uint8 across devices and torch
            # versions, so (like torchvision's resize) interpolate in float32 and convert back.
            needs_cast = out_dtype == torch.uint8 and mode != 'nearest'
            source = tensor.to(torch.float32) if needs_cast else tensor
            resized = F.interpolate(
                source,
                size=(new_h, new_w),
                mode=mode,
                # align_corners is only accepted for the interpolating modes
                align_corners=False if smooth else None,
                antialias=smooth,
            )
            if needs_cast:
                # Bicubic can overshoot the valid range; clamp before casting back to avoid wrap-around.
                resized = resized.round().clamp(0, 255).to(out_dtype)

            # Create padded tensor with background color
            # noinspection PyTypeChecker
            bg_color = torch.tensor(_parse_color_to_mode(
                self.background_color,
                mode=channel_modes[c]
            ), device=tensor.device)
            if out_dtype.is_floating_point:
                # The parsed color is divided by 255 so it matches float image data
                bg_color = (bg_color / 255.0).type(out_dtype)
            else:
                bg_color = bg_color.type(out_dtype)

            # clone() materializes the expanded view so it can be written into
            result = bg_color.reshape(1, c, 1, 1).expand(b, c, target_h, target_w).clone()

            # Calculate padding positions (centered; any odd leftover pixel goes to the right/bottom)
            pad_left = (target_w - new_w) // 2
            pad_top = (target_h - new_h) // 2

            # Paste resized image onto padded background
            result[:, :, pad_top:pad_top + new_h, pad_left:pad_left + new_w] = resized

            # Return to original batch dimension if needed
            if not is_batched:
                result = result.squeeze(0)

            return result

        def forward(self, pic):
            """
            Apply padding transform to input image.

            PIL images are handled by ``_pad_pil_image`` and tensors by ``_pad_tensor``.

            :param pic: Image to pad, either a PIL image or a tensor.
            :type pic: PIL.Image.Image or torch.Tensor
            :return: Padded image, as returned by the helper matching the input type.
            :rtype: PIL.Image.Image or torch.Tensor
            :raises TypeError: If input is neither a PIL Image nor a tensor
            """
            if isinstance(pic, Image.Image):
                return self._pad_pil_image(pic)
            elif isinstance(pic, torch.Tensor):
                return self._pad_tensor(pic)
            else:
                raise TypeError('pic should be PIL Image or a tensor. Got {}'.format(type(pic)))

        def __repr__(self) -> str:
            detail = f"(size={self.size}, interpolation={self.interpolation.value}, background_color={self.background_color})"