Commit b13596a0 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add rgba for restore models

parent 8d17334e
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run

@@ -61,6 +62,7 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
    :return: The restored image.
    :rtype: Image.Image
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]
@@ -75,4 +77,5 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
        process_title='NafNet Restore',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)
+4 −1
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run

@@ -57,6 +58,7 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
    :return: The restored image.
    :rtype: Image.Image
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]
@@ -71,4 +73,5 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
        process_title='SCUNet Restore',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)
+35 −0
Original line number Diff line number Diff line
from typing import Optional

import numpy as np
from PIL import Image

from ..data.image import ImageTyping, load_image


def _has_alpha_channel(image: Image.Image) -> bool:
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


def _rgba_preprocess(image: ImageTyping):
    image = load_image(image, force_background=None, mode=None)
    if _has_alpha_channel(image):
        image = image.convert('RGBA')
        pimage = image.convert('RGB')
        alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0
    else:
        pimage = image.convert('RGB')
        alpha_mask = None

    return pimage, alpha_mask


def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None):
    assert pimage.mode == 'RGB'
    if alpha_mask is None:
        return pimage
    else:
        alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis]
        channels = np.array(pimage)
        rgba_channels = np.concatenate([channels, alpha_channel], axis=-1)
        assert rgba_channels.shape == (*channels.shape[:-1], 4)
        return Image.fromarray(rgba_channels, mode='RGBA')