Commit 09a2b037 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use cv2.resize for alpha resize

parent 641f4f23
Loading
Loading
Loading
Loading
+11 −2
Original line number Diff line number Diff line
@@ -8,6 +8,11 @@ Overview:
    .. image:: cdc_demo.plot.py.svg
        :align: center

    Here is the benchmark of CDC models:

    .. image:: cdc_benchmark.plot.py.svg
        :align: center

    .. note::
        CDC model has high quality, and really low running speed.
        As we tested, when it upscales an image with 1024x1024 resolution on 2060 GPU,
@@ -17,6 +22,7 @@ Overview:
from functools import lru_cache
from typing import Tuple, Any

import cv2
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
@@ -58,7 +64,7 @@ _CDC_INPUT_UNIT = 16

def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
                     tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
                     silent: bool = False) -> Image.Image:
                     alpha_interpolation: int = cv2.INTER_LINEAR, silent: bool = False, ) -> Image.Image:
    """
    Upscale the input image using the CDC upscaler model.

@@ -77,6 +83,9 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
    :param batch_size: The batch size. (default: 1)
    :type batch_size: int

    :param alpha_interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``.
    :type alpha_interpolation: int

    :param silent: Whether to suppress progress messages. (default: False)
    :type silent: bool

@@ -127,4 +136,4 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.uint8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)
    return _rgba_postprocess(ret_image, alpha_mask, interpolation=alpha_interpolation)
+7 −8
Original line number Diff line number Diff line
from typing import Optional

import cv2
import numpy as np
import scipy.ndimage
from PIL import Image

from ..data.image import ImageTyping, load_image
@@ -42,7 +42,7 @@ def _rgba_preprocess(image: ImageTyping):
    return pimage, alpha_mask


def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None):
def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None, interpolation: int = cv2.INTER_LINEAR):
    """
    Postprocess the image after RGBA conversion.

@@ -52,6 +52,9 @@ def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None):
    :param alpha_mask: The alpha mask.
    :type alpha_mask: Optional[np.ndarray]

    :param interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``.
    :type interpolation: int

    :return: Postprocessed image.
    :rtype: Image.Image
    """
@@ -60,12 +63,8 @@ def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None):
        return pimage
    else:
        channels = np.array(pimage)
        alpha_mask = scipy.ndimage.zoom(
            alpha_mask,
            np.array(channels.shape[:2]) / np.array(alpha_mask.shape),
            mode='nearest',
            order=1,
        )
        alpha_mask = cv2.resize(alpha_mask, channels.shape[:2], interpolation=interpolation)
        alpha_mask = np.clip(alpha_mask, a_min=0.0, a_max=1.0)
        alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis]
        rgba_channels = np.concatenate([channels, alpha_channel], axis=-1)
        assert rgba_channels.shape == (*channels.shape[:-1], 4)
+189 KiB
Loading image diff...
+1.81 MiB
Loading image diff...
+10 −0
Original line number Diff line number Diff line
@@ -12,3 +12,13 @@ def sample_image():
@pytest.fixture()
def sample_image_small(sample_image):
    yield sample_image.resize((127, 126))


@pytest.fixture()
def sample_rgba_image():
    yield load_image(get_testfile('rgba_upscale.png'), mode='RGBA', force_background=None)


@pytest.fixture()
def sample_rgba_image_4x():
    yield load_image(get_testfile('rgba_upscale_4x.png'), mode='RGBA', force_background=None)
Loading