Loading imgutils/restore/nafnet.py +10 −1 Original line number Diff line number Diff line Loading @@ -14,6 +14,12 @@ Overview: Currently, we've identified a significant issue with NafNet when images contain gaussian noise. To ensure your code functions correctly, please ensure the credibility of your image source or preprocess them using SCUNet. .. note:: New in version v0.4.4, **images with alpha channel supported**. If you use an image with alpha channel (e.g. RGBA images), it will return a RGBA image, otherwise return RGG image. """ from functools import lru_cache from typing import Literal Loading @@ -22,6 +28,7 @@ import numpy as np from PIL import Image from huggingface_hub import hf_hub_download from .transparent import _rgba_preprocess, _rgba_postprocess from ..data import ImageTyping, load_image from ..utils import open_onnx_model, area_batch_run Loading Loading @@ -61,6 +68,7 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS', :return: The restored image. :rtype: Image.Image """ image, alpha_mask = _rgba_preprocess(image) image = load_image(image, mode='RGB', force_background='white') input_ = np.array(image).astype(np.float32) / 255.0 input_ = input_.transpose((2, 0, 1))[None, ...] Loading @@ -75,4 +83,5 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS', process_title='NafNet Restore', ) output_ = np.clip(output_, a_min=0.0, a_max=1.0) return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') return _rgba_postprocess(ret_image, alpha_mask) imgutils/restore/scunet.py +9 −1 Original line number Diff line number Diff line Loading @@ -10,6 +10,11 @@ Overview: .. image:: scunet_benchmark.plot.py.svg :align: center .. note:: New in version v0.4.4, **images with alpha channel supported**. If you use an image with alpha channel (e.g. RGBA images), it will return a RGBA image, otherwise return RGG image. """ from functools import lru_cache from typing import Literal Loading @@ -18,6 +23,7 @@ import numpy as np from PIL import Image from huggingface_hub import hf_hub_download from .transparent import _rgba_preprocess, _rgba_postprocess from ..data import ImageTyping, load_image from ..utils import open_onnx_model, area_batch_run Loading Loading @@ -57,6 +63,7 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN', :return: The restored image. :rtype: Image.Image """ image, alpha_mask = _rgba_preprocess(image) image = load_image(image, mode='RGB', force_background='white') input_ = np.array(image).astype(np.float32) / 255.0 input_ = input_.transpose((2, 0, 1))[None, ...] Loading @@ -71,4 +78,5 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN', process_title='SCUNet Restore', ) output_ = np.clip(output_, a_min=0.0, a_max=1.0) return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') return _rgba_postprocess(ret_image, alpha_mask) imgutils/restore/transparent.py 0 → 100644 +65 −0 Original line number Diff line number Diff line from typing import Optional import numpy as np from PIL import Image from ..data.image import ImageTyping, load_image def _has_alpha_channel(image: Image.Image) -> bool: """ Check if the image has an alpha channel. :param image: The image to check. :type image: Image.Image :return: True if the image has an alpha channel, False otherwise. :rtype: bool """ return any(band in {'A', 'a', 'P'} for band in image.getbands()) def _rgba_preprocess(image: ImageTyping): """ Preprocess the image for RGBA conversion. :param image: The image to preprocess. :type image: ImageTyping :return: Preprocessed image and alpha mask. :rtype: Tuple[Image.Image, Optional[np.ndarray]] """ image = load_image(image, force_background=None, mode=None) if _has_alpha_channel(image): image = image.convert('RGBA') pimage = image.convert('RGB') alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0 else: pimage = image.convert('RGB') alpha_mask = None return pimage, alpha_mask def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None): """ Postprocess the image after RGBA conversion. :param pimage: The processed image. :type pimage: Image.Image :param alpha_mask: The alpha mask. :type alpha_mask: Optional[np.ndarray] :return: Postprocessed image. :rtype: Image.Image """ assert pimage.mode == 'RGB' if alpha_mask is None: return pimage else: alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis] channels = np.array(pimage) rgba_channels = np.concatenate([channels, alpha_channel], axis=-1) assert rgba_channels.shape == (*channels.shape[:-1], 4) return Image.fromarray(rgba_channels, mode='RGBA') test/restore/conftest.py +5 −0 Original line number Diff line number Diff line Loading @@ -39,3 +39,8 @@ def q45_image(sample_image): img_file = os.path.join(td, 'image.jpg') sample_image.save(img_file, quality=45) yield load_image(img_file) @pytest.fixture() def rgba_image(): yield load_image(get_testfile('rgba_restore.png'), mode='RGBA', force_background=None) test/restore/test_nafnet.py +7 −0 Original line number Diff line number Diff line import pytest from imgutils.data import grid_transparent from imgutils.metrics import psnr from imgutils.restore import restore_with_nafnet from imgutils.restore.nafnet import _open_nafnet_model Loading @@ -20,3 +21,9 @@ class TestRestoreNafNet: def test_restore_with_nafnet_q45(self, q45_image, clear_image): assert psnr(restore_with_nafnet(q45_image), clear_image) >= 40.0 def test_restore_with_nafnet_rgba(self, rgba_image): assert rgba_image.mode == 'RGBA' restored_image = restore_with_nafnet(rgba_image) assert restored_image.mode == 'RGBA' assert psnr(grid_transparent(restored_image), grid_transparent(rgba_image), ) >= 35 Loading
imgutils/restore/nafnet.py +10 −1 Original line number Diff line number Diff line Loading @@ -14,6 +14,12 @@ Overview: Currently, we've identified a significant issue with NafNet when images contain gaussian noise. To ensure your code functions correctly, please ensure the credibility of your image source or preprocess them using SCUNet. .. note:: New in version v0.4.4, **images with alpha channel supported**. If you use an image with alpha channel (e.g. RGBA images), it will return a RGBA image, otherwise return RGG image. """ from functools import lru_cache from typing import Literal Loading @@ -22,6 +28,7 @@ import numpy as np from PIL import Image from huggingface_hub import hf_hub_download from .transparent import _rgba_preprocess, _rgba_postprocess from ..data import ImageTyping, load_image from ..utils import open_onnx_model, area_batch_run Loading Loading @@ -61,6 +68,7 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS', :return: The restored image. :rtype: Image.Image """ image, alpha_mask = _rgba_preprocess(image) image = load_image(image, mode='RGB', force_background='white') input_ = np.array(image).astype(np.float32) / 255.0 input_ = input_.transpose((2, 0, 1))[None, ...] Loading @@ -75,4 +83,5 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS', process_title='NafNet Restore', ) output_ = np.clip(output_, a_min=0.0, a_max=1.0) return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') return _rgba_postprocess(ret_image, alpha_mask)
imgutils/restore/scunet.py +9 −1 Original line number Diff line number Diff line Loading @@ -10,6 +10,11 @@ Overview: .. image:: scunet_benchmark.plot.py.svg :align: center .. note:: New in version v0.4.4, **images with alpha channel supported**. If you use an image with alpha channel (e.g. RGBA images), it will return a RGBA image, otherwise return RGG image. """ from functools import lru_cache from typing import Literal Loading @@ -18,6 +23,7 @@ import numpy as np from PIL import Image from huggingface_hub import hf_hub_download from .transparent import _rgba_preprocess, _rgba_postprocess from ..data import ImageTyping, load_image from ..utils import open_onnx_model, area_batch_run Loading Loading @@ -57,6 +63,7 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN', :return: The restored image. :rtype: Image.Image """ image, alpha_mask = _rgba_preprocess(image) image = load_image(image, mode='RGB', force_background='white') input_ = np.array(image).astype(np.float32) / 255.0 input_ = input_.transpose((2, 0, 1))[None, ...] Loading @@ -71,4 +78,5 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN', process_title='SCUNet Restore', ) output_ = np.clip(output_, a_min=0.0, a_max=1.0) return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB') return _rgba_postprocess(ret_image, alpha_mask)
imgutils/restore/transparent.py 0 → 100644 +65 −0 Original line number Diff line number Diff line from typing import Optional import numpy as np from PIL import Image from ..data.image import ImageTyping, load_image def _has_alpha_channel(image: Image.Image) -> bool: """ Check if the image has an alpha channel. :param image: The image to check. :type image: Image.Image :return: True if the image has an alpha channel, False otherwise. :rtype: bool """ return any(band in {'A', 'a', 'P'} for band in image.getbands()) def _rgba_preprocess(image: ImageTyping): """ Preprocess the image for RGBA conversion. :param image: The image to preprocess. :type image: ImageTyping :return: Preprocessed image and alpha mask. :rtype: Tuple[Image.Image, Optional[np.ndarray]] """ image = load_image(image, force_background=None, mode=None) if _has_alpha_channel(image): image = image.convert('RGBA') pimage = image.convert('RGB') alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0 else: pimage = image.convert('RGB') alpha_mask = None return pimage, alpha_mask def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None): """ Postprocess the image after RGBA conversion. :param pimage: The processed image. :type pimage: Image.Image :param alpha_mask: The alpha mask. :type alpha_mask: Optional[np.ndarray] :return: Postprocessed image. :rtype: Image.Image """ assert pimage.mode == 'RGB' if alpha_mask is None: return pimage else: alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis] channels = np.array(pimage) rgba_channels = np.concatenate([channels, alpha_channel], axis=-1) assert rgba_channels.shape == (*channels.shape[:-1], 4) return Image.fromarray(rgba_channels, mode='RGBA')
test/restore/conftest.py +5 −0 Original line number Diff line number Diff line Loading @@ -39,3 +39,8 @@ def q45_image(sample_image): img_file = os.path.join(td, 'image.jpg') sample_image.save(img_file, quality=45) yield load_image(img_file) @pytest.fixture() def rgba_image(): yield load_image(get_testfile('rgba_restore.png'), mode='RGBA', force_background=None)
test/restore/test_nafnet.py +7 −0 Original line number Diff line number Diff line import pytest from imgutils.data import grid_transparent from imgutils.metrics import psnr from imgutils.restore import restore_with_nafnet from imgutils.restore.nafnet import _open_nafnet_model Loading @@ -20,3 +21,9 @@ class TestRestoreNafNet: def test_restore_with_nafnet_q45(self, q45_image, clear_image): assert psnr(restore_with_nafnet(q45_image), clear_image) >= 40.0 def test_restore_with_nafnet_rgba(self, rgba_image): assert rgba_image.mode == 'RGBA' restored_image = restore_with_nafnet(rgba_image) assert restored_image.mode == 'RGBA' assert psnr(grid_transparent(restored_image), grid_transparent(rgba_image), ) >= 35