Commit 70c0f1ed authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): try fix that

parent 0bf4edec
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run

@@ -33,6 +34,7 @@ _CDC_INPUT_UNIT = 16
def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
                     tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
                     silent: bool = False) -> Image.Image:
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]
@@ -60,4 +62,5 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
        scale=scale, silent=silent, process_title='CDC Upscale',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    return [_rgba_postprocess(ret_image, alpha_mask, order=i) for i in range(6)]
+70 −0
Original line number Diff line number Diff line
from typing import Optional

import numpy as np
import scipy.ndimage
from PIL import Image

from ..data.image import ImageTyping, load_image


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Check if the image has an alpha channel.

    :param image: The image to check.
    :type image: Image.Image

    :return: True if the image has an alpha channel, False otherwise.
    :rtype: bool
    """
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


def _rgba_preprocess(image: ImageTyping):
    """
    Preprocess the image for RGBA conversion.

    :param image: The image to preprocess.
    :type image: ImageTyping

    :return: Preprocessed image and alpha mask.
    :rtype: Tuple[Image.Image, Optional[np.ndarray]]
    """
    image = load_image(image, force_background=None, mode=None)
    if _has_alpha_channel(image):
        image = image.convert('RGBA')
        pimage = image.convert('RGB')
        alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0
    else:
        pimage = image.convert('RGB')
        alpha_mask = None

    return pimage, alpha_mask


def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None, order: int = 1):
    """
    Postprocess the image after RGBA conversion.

    :param pimage: The processed image.
    :type pimage: Image.Image

    :param alpha_mask: The alpha mask.
    :type alpha_mask: Optional[np.ndarray]

    :return: Postprocessed image.
    :rtype: Image.Image
    """
    assert pimage.mode == 'RGB'
    if alpha_mask is None:
        return pimage
    else:
        channels = np.array(pimage)
        alpha_mask = scipy.ndimage.zoom(
            alpha_mask, np.array(channels.shape[:2]) / np.array(alpha_mask.shape),
            order=1, mode='nearest',
        )
        alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis]
        rgba_channels = np.concatenate([channels, alpha_channel], axis=-1)
        assert rgba_channels.shape == (*channels.shape[:-1], 4)
        return Image.fromarray(rgba_channels, mode='RGBA')