Commit 255988ec authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use enhance layer

parent ba04f1cd
Loading
Loading
Loading
Loading
+34 −19
Original line number Diff line number Diff line
@@ -28,8 +28,8 @@ import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..data import ImageTyping
from ..generic import ImageEnhancer
from ..utils import open_onnx_model, area_batch_run

NafNetModelTyping = Literal['REDS', 'GoPro', 'SIDD']
@@ -50,6 +50,37 @@ def _open_nafnet_model(model: NafNetModelTyping):
    ))


class _Enhancer(ImageEnhancer):
    def __init__(self, model: NafNetModelTyping = 'REDS', tile_size: int = 256, tile_overlap: int = 16,
                 batch_size: int = 4, silent: bool = False):
        self.model = model
        self.tile_size = tile_size
        self.tile_overlap = tile_overlap
        self.batch_size = batch_size
        self.silent = silent

    def _process_rgb(self, rgb_array: np.ndarray):
        input_ = rgb_array[None, ...]

        def _method(ix):
            ox, = _open_nafnet_model(self.model).run(['output'], {'input': ix})
            return ox

        output_ = area_batch_run(
            input_, _method,
            tile_size=self.tile_size, tile_overlap=self.tile_overlap, batch_size=self.batch_size,
            silent=self.silent, process_title='NafNet Restore',
        )
        output_ = np.clip(output_, a_min=0.0, a_max=1.0)
        return output_[0]


@lru_cache()
def _get_enhancer(model: NafNetModelTyping = 'REDS', tile_size: int = 256, tile_overlap: int = 16,
                  batch_size: int = 4, silent: bool = False) -> _Enhancer:
    return _Enhancer(model, tile_size, tile_overlap, batch_size, silent)


def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
                        tile_size: int = 256, tile_overlap: int = 16, batch_size: int = 4,
                        silent: bool = False) -> Image.Image:
@@ -71,20 +102,4 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
    :return: The restored image.
    :rtype: Image.Image
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]

    def _method(ix):
        ox, = _open_nafnet_model(model).run(['output'], {'input': ix})
        return ox

    output_ = area_batch_run(
        input_, _method,
        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
        silent=silent, process_title='NafNet Restore',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)
    return _get_enhancer(model, tile_size, tile_overlap, batch_size, silent).process(image)
+34 −19
Original line number Diff line number Diff line
@@ -23,8 +23,8 @@ import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..data import ImageTyping
from ..generic import ImageEnhancer
from ..utils import open_onnx_model, area_batch_run

SCUNetModelTyping = Literal['GAN', 'PSNR']
@@ -45,6 +45,37 @@ def _open_scunet_model(model: SCUNetModelTyping):
    ))


class _Enhancer(ImageEnhancer):
    def __init__(self, model: SCUNetModelTyping = 'GAN', tile_size: int = 128, tile_overlap: int = 16,
                 batch_size: int = 4, silent: bool = False):
        self.model = model
        self.tile_size = tile_size
        self.tile_overlap = tile_overlap
        self.batch_size = batch_size
        self.silent = silent

    def _process_rgb(self, rgb_array: np.ndarray):
        input_ = rgb_array[None, ...]

        def _method(ix):
            ox, = _open_scunet_model(self.model).run(['output'], {'input': ix})
            return ox

        output_ = area_batch_run(
            input_, _method,
            tile_size=self.tile_size, tile_overlap=self.tile_overlap, batch_size=self.batch_size,
            silent=self.silent, process_title='SCUNet Restore',
        )
        output_ = np.clip(output_, a_min=0.0, a_max=1.0)
        return output_[0]


@lru_cache()
def _get_enhancer(model: SCUNetModelTyping = 'GAN', tile_size: int = 128, tile_overlap: int = 16,
                  batch_size: int = 4, silent: bool = False) -> _Enhancer:
    return _Enhancer(model, tile_size, tile_overlap, batch_size, silent)


def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
                        tile_size: int = 128, tile_overlap: int = 16, batch_size: int = 4,
                        silent: bool = False) -> Image.Image:
@@ -66,20 +97,4 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
    :return: The restored image.
    :rtype: Image.Image
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]

    def _method(ix):
        ox, = _open_scunet_model(model).run(['output'], {'input': ix})
        return ox

    output_ = area_batch_run(
        input_, _method,
        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
        silent=silent, process_title='SCUNet Restore',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)
    return _get_enhancer(model, tile_size, tile_overlap, batch_size, silent).process(image)

imgutils/restore/transparent.py

deleted100644 → 0
+0 −65
Original line number Diff line number Diff line
from typing import Optional

import numpy as np
from PIL import Image

from ..data.image import ImageTyping, load_image


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Check if the image has an alpha channel.

    :param image: The image to check.
    :type image: Image.Image

    :return: True if the image has an alpha channel, False otherwise.
    :rtype: bool
    """
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


def _rgba_preprocess(image: ImageTyping):
    """
    Preprocess the image for RGBA conversion.

    :param image: The image to preprocess.
    :type image: ImageTyping

    :return: Preprocessed image and alpha mask.
    :rtype: Tuple[Image.Image, Optional[np.ndarray]]
    """
    image = load_image(image, force_background=None, mode=None)
    if _has_alpha_channel(image):
        image = image.convert('RGBA')
        pimage = image.convert('RGB')
        alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0
    else:
        pimage = image.convert('RGB')
        alpha_mask = None

    return pimage, alpha_mask


def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None):
    """
    Postprocess the image after RGBA conversion.

    :param pimage: The processed image.
    :type pimage: Image.Image

    :param alpha_mask: The alpha mask.
    :type alpha_mask: Optional[np.ndarray]

    :return: Postprocessed image.
    :rtype: Image.Image
    """
    assert pimage.mode == 'RGB'
    if alpha_mask is None:
        return pimage
    else:
        alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis]
        channels = np.array(pimage)
        rgba_channels = np.concatenate([channels, alpha_channel], axis=-1)
        assert rgba_channels.shape == (*channels.shape[:-1], 4)
        return Image.fromarray(rgba_channels, mode='RGBA')