Unverified Commit 5de99ab1 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #87 from deepghs/dev/restore

dev(narugo): add rgba support for restore models
parents 8d17334e d1ea273a
Loading
Loading
Loading
Loading
+10 −1
Original line number Diff line number Diff line
@@ -14,6 +14,12 @@ Overview:
        Currently, we've identified a significant issue with NafNet when images contain gaussian noise.
        To ensure your code functions correctly, please ensure the credibility of
        your image source or preprocess them using SCUNet.

    .. note::
        New in version v0.4.4, **images with alpha channel supported**.

        If you use an image with alpha channel (e.g. RGBA images),
        it will return an RGBA image; otherwise it returns an RGB image.
"""
from functools import lru_cache
from typing import Literal
@@ -22,6 +28,7 @@ import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run

@@ -61,6 +68,7 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
    :return: The restored image.
    :rtype: Image.Image
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]
@@ -75,4 +83,5 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
        process_title='NafNet Restore',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)
+9 −1
Original line number Diff line number Diff line
@@ -10,6 +10,11 @@ Overview:
    .. image:: scunet_benchmark.plot.py.svg
        :align: center

    .. note::
        New in version v0.4.4, **images with alpha channel supported**.

        If you use an image with alpha channel (e.g. RGBA images),
        it will return an RGBA image; otherwise it returns an RGB image.
"""
from functools import lru_cache
from typing import Literal
@@ -18,6 +23,7 @@ import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run

@@ -57,6 +63,7 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
    :return: The restored image.
    :rtype: Image.Image
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]
@@ -71,4 +78,5 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
        process_title='SCUNet Restore',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)
+65 −0
Original line number Diff line number Diff line
from typing import Optional

import numpy as np
from PIL import Image

from ..data.image import ImageTyping, load_image


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Tell whether an image should be treated as carrying transparency.

    The decision is made purely on band names: the image qualifies as soon as
    one of its bands is called ``A``, ``a`` or ``P``.

    :param image: The image to inspect.
    :type image: Image.Image

    :return: ``True`` if any band is one of ``A``, ``a`` or ``P``, ``False`` otherwise.
    :rtype: bool
    """
    alpha_like_bands = {'A', 'a', 'P'}
    for band_name in image.getbands():
        if band_name in alpha_like_bands:
            return True
    return False


def _rgba_preprocess(image: ImageTyping):
    """
    Split an input image into an RGB picture and an optional alpha mask.

    The image is loaded without forcing a mode or a background. When it has an
    alpha channel (as decided by :func:`_has_alpha_channel`), it is converted
    to RGBA and the fourth channel is extracted as a ``float32`` array scaled
    by ``1 / 255``. Otherwise no mask is produced.

    :param image: The image to preprocess.
    :type image: ImageTyping

    :return: The RGB version of the image, and the alpha mask
        (``None`` when the image has no alpha channel).
    :rtype: Tuple[Image.Image, Optional[np.ndarray]]
    """
    loaded = load_image(image, force_background=None, mode=None)

    # Opaque input: nothing to remember, just hand back the RGB picture.
    if not _has_alpha_channel(loaded):
        return loaded.convert('RGB'), None

    rgba = loaded.convert('RGBA')
    rgb = rgba.convert('RGB')
    # Index 3 of the last axis is the alpha channel of the RGBA array.
    alpha_mask = np.array(rgba)[:, :, 3].astype(np.float32) / 255.0
    return rgb, alpha_mask


def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None):
    """
    Re-attach an alpha mask to a processed RGB image.

    When no mask is given, the RGB image is returned untouched. Otherwise the
    mask is scaled by 255, cast to ``uint8`` and appended as a fourth channel,
    producing an RGBA image.

    :param pimage: The processed image, which must be in ``RGB`` mode.
    :type pimage: Image.Image

    :param alpha_mask: The alpha mask, or ``None`` if the source image had no alpha channel.
    :type alpha_mask: Optional[np.ndarray]

    :return: ``pimage`` itself when ``alpha_mask`` is ``None``, otherwise a new RGBA image.
    :rtype: Image.Image
    """
    assert pimage.mode == 'RGB'
    if alpha_mask is None:
        return pimage

    rgb = np.array(pimage)
    alpha = (alpha_mask * 255.0).astype(np.uint8)
    # Give the mask a trailing channel axis so it can be stacked after the RGB channels.
    rgba = np.concatenate([rgb, alpha[..., np.newaxis]], axis=-1)
    assert rgba.shape == (*rgb.shape[:-1], 4)
    return Image.fromarray(rgba, mode='RGBA')
+5 −0
Original line number Diff line number Diff line
@@ -39,3 +39,8 @@ def q45_image(sample_image):
        img_file = os.path.join(td, 'image.jpg')
        sample_image.save(img_file, quality=45)
        yield load_image(img_file)


@pytest.fixture()
def rgba_image():
    """
    Provide the ``rgba_restore.png`` test file, loaded in ``RGBA`` mode
    with no background forced onto it.
    """
    image_path = get_testfile('rgba_restore.png')
    yield load_image(image_path, mode='RGBA', force_background=None)
+7 −0
Original line number Diff line number Diff line
import pytest

from imgutils.data import grid_transparent
from imgutils.metrics import psnr
from imgutils.restore import restore_with_nafnet
from imgutils.restore.nafnet import _open_nafnet_model
@@ -20,3 +21,9 @@ class TestRestoreNafNet:

    def test_restore_with_nafnet_q45(self, q45_image, clear_image):
        """Restoring the ``q45_image`` sample must reach a PSNR of at least 40 against the clear image."""
        restored_image = restore_with_nafnet(q45_image)
        score = psnr(restored_image, clear_image)
        assert score >= 40.0

    def test_restore_with_nafnet_rgba(self, rgba_image):
        """
        An RGBA input must come back as RGBA, and stay close to the input
        (PSNR of at least 35) once both are passed through ``grid_transparent``.
        """
        assert rgba_image.mode == 'RGBA'

        restored_image = restore_with_nafnet(rgba_image)
        assert restored_image.mode == 'RGBA'

        restored_on_grid = grid_transparent(restored_image)
        original_on_grid = grid_transparent(rgba_image)
        assert psnr(restored_on_grid, original_on_grid) >= 35
Loading