Commit 5685cc5c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add enhance for cdc upscaler

parent f72c4f4f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from .classify import *
from .enhance import *
+54 −0
Original line number Diff line number Diff line
import numpy as np
from PIL import Image

from ..data import ImageTyping, load_image

__all__ = [
    'ImageEnhancer',
]


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Tell whether an image should be treated as carrying transparency.

    The image counts as having an alpha channel when at least one of the
    band names reported by ``image.getbands()`` is ``'A'``, ``'a'`` or ``'P'``.

    :param image: The image to inspect.
    :type image: Image.Image

    :return: ``True`` if any such band is present, ``False`` otherwise.
    :rtype: bool
    """
    alpha_like_bands = {'A', 'a', 'P'}
    return not alpha_like_bands.isdisjoint(image.getbands())


class ImageEnhancer:
    """
    Base class for image enhancers operating on channel-first float arrays.

    Subclasses provide :meth:`_process_rgb`; the remaining methods reuse it
    to handle the alpha channel and the conversion from/to PIL images.
    """

    def _process_rgb(self, rgb_array: np.ndarray):
        """
        Enhance an RGB array. Subclasses must override this method.

        :param rgb_array: A ``(3, H, W)`` float32 array with values in ``[0.0, 1.0]``.
        :type rgb_array: np.ndarray

        :return: Another ``(3, H', W')`` float32 array with values in ``[0.0, 1.0]``.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _process_alpha_channel_with_model(self, alpha_array: np.ndarray):
        """
        Enhance a single alpha plane by running it through :meth:`_process_rgb`.

        The plane is replicated into three identical channels, processed as if
        it were an RGB image, and the three output channels are averaged back
        into one plane.

        :param alpha_array: A 2-dim alpha plane.
        :type alpha_array: np.ndarray

        :return: The enhanced 2-dim alpha plane.
        """
        assert len(alpha_array.shape) == 2, f'Alpha array should be 2-dim, but {alpha_array.shape!r} found.'
        pseudo_rgb = np.stack([alpha_array] * 3)
        enhanced_rgb = self._process_rgb(pseudo_rgb)
        return enhanced_rgb.mean(axis=0)

    def _process_rgba(self, rgba_array: np.ndarray):
        """
        Enhance an RGBA array: colour channels and alpha channel are processed
        separately and then joined again along the channel axis.

        :param rgba_array: A 3-dim array whose first axis has 4 entries (R, G, B, A).
        :type rgba_array: np.ndarray

        :return: The enhanced array with 4 channels on the first axis.
        """
        assert len(rgba_array.shape) == 3 and rgba_array.shape[0] == 4, \
            f'RGBA array should be 3-dim and 4-channels, but {rgba_array.shape!r} found.'

        colour_part = self._process_rgb(rgba_array[:3, ...])
        alpha_part = self._process_alpha_channel_with_model(rgba_array[3, ...])
        # Give the alpha plane a leading channel axis so it can be concatenated.
        return np.concatenate([colour_part, alpha_part[None, ...]], axis=0)

    def process(self, image: ImageTyping):
        """
        Enhance an image and return the result as a PIL image.

        The image is loaded once to detect transparency, then reloaded in
        ``'RGBA'`` or ``'RGB'`` mode accordingly, converted to a channel-first
        float32 array scaled by ``1 / 255``, enhanced, clipped to ``[0.0, 1.0]``
        and converted back to ``uint8``.

        :param image: The image to enhance.
        :type image: ImageTyping

        :return: The enhanced image, in the same mode that was chosen for loading.
        """
        probed = load_image(image, mode=None, force_background=None)
        mode = 'RGBA' if _has_alpha_channel(probed) else 'RGB'
        converted = load_image(probed, mode=mode, force_background=None)

        # (H, W, C) uint8 -> (C, H, W) float32
        scaled = np.array(converted).astype(np.float32) / 255.0
        channel_first = scaled.transpose((2, 0, 1))

        handler = self._process_rgba if _has_alpha_channel(converted) else self._process_rgb
        enhanced = handler(channel_first)

        # (C, H', W') float -> (H', W', C) uint8; astype truncates toward zero.
        rescaled = np.clip(enhanced, a_min=0.0, a_max=1.0) * 255.0
        channel_last = rescaled.astype(np.uint8).transpose((1, 2, 0))
        return Image.fromarray(channel_last, mode=mode)
+5 −1
Original line number Diff line number Diff line
from .cdc import upscale_with_cdc
"""
Overview:
    Upscale image to a larger size.
"""
from .cdc import *
+63 −37
Original line number Diff line number Diff line
@@ -22,15 +22,18 @@ Overview:
from functools import lru_cache
from typing import Tuple, Any

import cv2
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..data import ImageTyping
from ..generic import ImageEnhancer
from ..utils import open_onnx_model, area_batch_run

__all__ = [
    'upscale_with_cdc',
]


@lru_cache()
def _open_cdc_upscaler_model(model: str) -> Tuple[Any, int]:
@@ -62,9 +65,64 @@ def _open_cdc_upscaler_model(model: str) -> Tuple[Any, int]:
# Height and width of each model input are padded up to a multiple of this value.
_CDC_INPUT_UNIT = 16


def _upscale_for_rgb(input_: np.ndarray, model: str = 'HGSR-MHR-anime-aug_X4_320',
                     tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1, silent: bool = False):
    """
    Upscale one RGB array with the CDC upscaler model, tile by tile.

    :param input_: A 4-dim array of shape ``(1, 3, H, W)``.
    :type input_: np.ndarray

    :param model: Model name, passed to :func:`_open_cdc_upscaler_model`.
    :type model: str

    :param tile_size: Tile size, forwarded to :func:`area_batch_run`.
    :type tile_size: int

    :param tile_overlap: Tile overlap, forwarded to :func:`area_batch_run`.
    :type tile_overlap: int

    :param batch_size: Batch size, forwarded to :func:`area_batch_run`.
    :type batch_size: int

    :param silent: Silent flag, forwarded to :func:`area_batch_run`.
    :type silent: bool

    :return: The result of :func:`area_batch_run`, clipped to ``[0.0, 1.0]``.
    """
    assert len(input_.shape) == 4 and input_.shape[:2] == (1, 3)
    session, scale = _open_cdc_upscaler_model(model)

    def _run_tile(tile):
        tile = tile.astype(np.float32)
        _, _, src_height, src_width = tile.shape

        # Distance from each spatial size up to the next multiple of the input unit.
        pad_bottom = (-src_height) % _CDC_INPUT_UNIT
        pad_right = (-src_width) % _CDC_INPUT_UNIT
        if pad_bottom > 0 or pad_right > 0:  # align to 16
            tile = np.pad(tile, ((0, 0), (0, 0), (0, pad_bottom), (0, pad_right)), mode='reflect')

        [raw] = session.run(['output'], {'input': tile})
        # The model output is unpacked as six dimensions and collapsed to four.
        n_batch, n_channels, _, out_height, ratio, out_width = raw.shape
        merged = raw.reshape((n_batch, n_channels, ratio * out_height, ratio * out_width))
        # Drop the region that only exists because of the alignment padding.
        return merged[..., :ratio * src_height, :ratio * src_width]

    upscaled = area_batch_run(
        input_, _run_tile,
        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
        scale=scale, silent=silent, process_title='CDC Upscale',
    )
    return np.clip(upscaled, a_min=0.0, a_max=1.0)


class _Enhancer(ImageEnhancer):
    """
    :class:`ImageEnhancer` whose RGB processing is delegated to
    :func:`_upscale_for_rgb` with a fixed set of options.
    """

    def __init__(self, model: str = 'HGSR-MHR-anime-aug_X4_320',
                 tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1, silent: bool = False):
        """
        Store the options later handed to :func:`_upscale_for_rgb`.

        :param model: Model name.
        :type model: str

        :param tile_size: Tile size.
        :type tile_size: int

        :param tile_overlap: Tile overlap.
        :type tile_overlap: int

        :param batch_size: Batch size.
        :type batch_size: int

        :param silent: Silent flag.
        :type silent: bool
        """
        self.model, self.silent = model, silent
        self.tile_size, self.tile_overlap, self.batch_size = tile_size, tile_overlap, batch_size

    def _process_rgb(self, rgb_array: np.ndarray):
        """
        Upscale one RGB array via :func:`_upscale_for_rgb`.

        :param rgb_array: The RGB array; a leading axis is added before the call.
        :type rgb_array: np.ndarray

        :return: The first entry along the leading axis of the upscaled result.
        """
        batched = rgb_array[np.newaxis]
        upscaled = _upscale_for_rgb(
            batched,
            model=self.model,
            tile_size=self.tile_size,
            tile_overlap=self.tile_overlap,
            batch_size=self.batch_size,
            silent=self.silent,
        )
        return upscaled[0]


@lru_cache()
def _get_enhancer(model: str = 'HGSR-MHR-anime-aug_X4_320',
                  tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1, silent: bool = False):
    """
    Build an :class:`_Enhancer` for the given options.

    Results are memoized by :func:`functools.lru_cache`, so repeated calls
    with the same arguments return the same instance.

    :param model: Model name.
    :type model: str

    :param tile_size: Tile size.
    :type tile_size: int

    :param tile_overlap: Tile overlap.
    :type tile_overlap: int

    :param batch_size: Batch size.
    :type batch_size: int

    :param silent: Silent flag.
    :type silent: bool

    :return: The enhancer instance.
    """
    return _Enhancer(
        model=model,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        batch_size=batch_size,
        silent=silent,
    )


def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
                     tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
                     alpha_interpolation: int = cv2.INTER_LINEAR, silent: bool = False, ) -> Image.Image:
                     silent: bool = False) -> Image.Image:
    """
    Upscale the input image using the CDC upscaler model.

@@ -83,9 +141,6 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
    :param batch_size: The batch size. (default: 1)
    :type batch_size: int

    :param alpha_interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``.
    :type alpha_interpolation: int

    :param silent: Whether to suppress progress messages. (default: False)
    :type silent: bool

@@ -107,33 +162,4 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
        >>> upscale_with_cdc(image)
        <PIL.Image.Image image mode=RGBA size=4672x4672 at 0x7F0E48EDB640>
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]

    ort, scale = _open_cdc_upscaler_model(model)

    def _method(ix):
        ix = ix.astype(np.float32)
        batch, channels, height, width = ix.shape
        p_height = 0 if height % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (height % _CDC_INPUT_UNIT)
        p_width = 0 if width % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (width % _CDC_INPUT_UNIT)
        if p_height > 0 or p_width > 0:  # align to 16
            ix = np.pad(ix, ((0, 0), (0, 0), (0, p_height), (0, p_width)), mode='reflect')
        actual_height, actual_width = height, width

        ox, = ort.run(['output'], {'input': ix})
        batch, channels, scale_, height, scale_, width = ox.shape
        ox = ox.reshape((batch, channels, scale_ * height, scale_ * width))
        ox = ox[..., :scale_ * actual_height, :scale_ * actual_width]  # crop back
        return ox

    output_ = area_batch_run(
        input_, _method,
        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
        scale=scale, silent=silent, process_title='CDC Upscale',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.uint8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask, interpolation=alpha_interpolation)
    return _get_enhancer(model, tile_size, tile_overlap, batch_size, silent).process(image)

imgutils/upscale/transparent.py

deleted100644 → 0
+0 −71
Original line number Diff line number Diff line
from typing import Optional

import cv2
import numpy as np
from PIL import Image

from ..data.image import ImageTyping, load_image


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Report whether the image exposes a band that is treated as transparency.

    The band names returned by ``image.getbands()`` are compared against
    ``'A'``, ``'a'`` and ``'P'``; one match is enough.

    :param image: The image to inspect.
    :type image: Image.Image

    :return: ``True`` if a matching band exists, ``False`` otherwise.
    :rtype: bool
    """
    transparency_bands = {'A', 'a', 'P'}
    return not transparency_bands.isdisjoint(image.getbands())


def _rgba_preprocess(image: ImageTyping):
    """
    Split an image into an RGB image and an optional alpha mask.

    :param image: The image to preprocess.
    :type image: ImageTyping

    :return: The RGB image and the alpha mask as a float32 array scaled by
        ``1 / 255``; the mask is ``None`` when no alpha channel is detected.
    :rtype: Tuple[Image.Image, Optional[np.ndarray]]
    """
    loaded = load_image(image, force_background=None, mode=None)
    if not _has_alpha_channel(loaded):
        return loaded.convert('RGB'), None

    rgba_image = loaded.convert('RGBA')
    rgb_image = rgba_image.convert('RGB')
    # The fourth channel of the RGBA pixel array is the alpha plane.
    alpha_plane = np.array(rgba_image)[:, :, 3]
    return rgb_image, alpha_plane.astype(np.float32) / 255.0


def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None, interpolation: int = cv2.INTER_LINEAR):
    """
    Postprocess the image after RGBA conversion.

    When an alpha mask is given, it is resized to the size of the processed
    image and attached to it as the fourth channel.

    :param pimage: The processed image, must be in ``RGB`` mode.
    :type pimage: Image.Image

    :param alpha_mask: The alpha mask, values are clipped to ``[0.0, 1.0]`` after resizing.
        When ``None``, ``pimage`` is returned unchanged.
    :type alpha_mask: Optional[np.ndarray]

    :param interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``.
    :type interpolation: int

    :return: Postprocessed image.
    :rtype: Image.Image
    """
    assert pimage.mode == 'RGB'
    if alpha_mask is None:
        return pimage

    channels = np.array(pimage)
    height, width = channels.shape[:2]
    # cv2.resize takes dsize as (width, height), while numpy shapes are (height, width).
    # Passing the shape directly would transpose the mask for non-square images.
    alpha_mask = cv2.resize(alpha_mask, (width, height), interpolation=interpolation)
    alpha_mask = np.clip(alpha_mask, a_min=0.0, a_max=1.0)
    alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis]
    rgba_channels = np.concatenate([channels, alpha_channel], axis=-1)
    assert rgba_channels.shape == (*channels.shape[:-1], 4)
    return Image.fromarray(rgba_channels, mode='RGBA')
Loading