Unverified Commit 902ab07f authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #116 from deepghs/dev/bg

dev(narugo): optimize background loading safefy
parents d39b4513 146f940d
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -18,3 +18,10 @@ load_images
.. autofunction:: load_images


has_alpha_channel
------------------------------

.. autofunction:: has_alpha_channel


+91 −7
Original line number Diff line number Diff line
"""
This module provides utility functions for image processing and manipulation using the PIL (Python Imaging Library) library.

It includes functions for loading images from various sources, handling multiple images, adding backgrounds to RGBA images,
and checking for alpha channels. The module is designed to simplify common image-related tasks in Python applications.

Key features:
- Loading images from different sources (file paths, binary data, file-like objects)
- Handling multiple images at once
- Adding backgrounds to RGBA images
- Checking for alpha channels in images

This module is particularly useful for applications that require image preprocessing or manipulation before further processing or analysis.
"""

from os import PathLike
from typing import Union, BinaryIO, List, Tuple, Optional

from PIL import Image

__all__ = [
    'ImageTyping', 'load_image',
    'MultiImagesTyping', 'load_images',
    'ImageTyping',
    'load_image',
    'MultiImagesTyping',
    'load_images',
    'add_background_for_rgba',
    'has_alpha_channel',
]


def _is_readable(obj):
    """
    Check if an object is readable (has 'read' and 'seek' methods).

    :param obj: The object to check for readability.
    :type obj: Any

    :return: True if the object is readable, False otherwise.
    :rtype: bool
    """
    return hasattr(obj, 'read') and hasattr(obj, 'seek')


@@ -18,8 +45,33 @@ ImageTyping = Union[str, PathLike, bytes, bytearray, BinaryIO, Image.Image]
MultiImagesTyping = Union[ImageTyping, List[ImageTyping], Tuple[ImageTyping, ...]]


def _has_alpha_channel(image: Image.Image) -> bool:
    return any(band in {'A', 'a', 'P'} for band in image.getbands())
def has_alpha_channel(image: Image.Image) -> bool:
    """
    Determine if the given Pillow image object has an alpha channel (transparency)

    :param image: Pillow image object
    :type image: Image.Image

    :return: Boolean, True if it has an alpha channel, False otherwise
    :rtype: bool
    """
    # Get the image mode
    mode = image.mode

    # Modes that directly include an alpha channel
    if mode in ('RGBA', 'LA', 'PA'):
        return True

    if getattr(image, 'palette'):
        # Check if there's a transparent palette
        try:
            image.palette.getcolor((0, 0, 0, 0))
            return True  # cannot find a line to trigger this
        except ValueError:
            pass

    # For other modes, check if 'transparency' key exists in image info
    return 'transparency' in image.info


def load_image(image: ImageTyping, mode=None, force_background: Optional[str] = 'white'):
@@ -43,6 +95,16 @@ def load_image(image: ImageTyping, mode=None, force_background: Optional[str] =

    :return: The loaded and transformed image.
    :rtype: Image.Image

    :raises TypeError: If the provided image type is not supported.

    :example:
    >>> from PIL import Image
    >>> img = load_image('path/to/image.png', mode='RGB', force_background='white')
    >>> isinstance(img, Image.Image)
    True
    >>> img.mode
    'RGB'
    """
    if isinstance(image, (str, PathLike, bytes, bytearray, BinaryIO)) or _is_readable(image):
        image = Image.open(image)
@@ -51,7 +113,7 @@ def load_image(image: ImageTyping, mode=None, force_background: Optional[str] =
    else:
        raise TypeError(f'Unknown image type - {image!r}.')

    if _has_alpha_channel(image) and force_background is not None:
    if has_alpha_channel(image) and force_background is not None:
        image = add_background_for_rgba(image, force_background)

    if mode is not None and image.mode != mode:
@@ -79,6 +141,14 @@ def load_images(images: MultiImagesTyping, mode=None, force_background: Optional

    :return: A list of loaded and transformed images.
    :rtype: List[Image.Image]

    :example:
    >>> img_paths = ['path/to/image1.png', 'path/to/image2.jpg']
    >>> loaded_images = load_images(img_paths, mode='RGB')
    >>> len(loaded_images)
    2
    >>> all(isinstance(img, Image.Image) for img in loaded_images)
    True
    """
    if not isinstance(images, (list, tuple)):
        images = [images]
@@ -102,6 +172,20 @@ def add_background_for_rgba(image: ImageTyping, background: str = 'white'):

    :return: The image with the added background, converted to RGB.
    :rtype: Image.Image

    :example:
    >>> from PIL import Image
    >>> rgba_image = Image.new('RGBA', (100, 100), (255, 0, 0, 128))
    >>> rgb_image = add_background_for_rgba(rgba_image, background='blue')
    >>> rgb_image.mode
    'RGB'
    """
    from .layer import istack
    return istack(background, image).convert('RGB')
    image = load_image(image, force_background=None, mode=None)
    try:
        ret_image = Image.new('RGBA', image.size, background)
        ret_image.paste(image, (0, 0), mask=image)
    except ValueError:
        ret_image = image
    if ret_image.mode != 'RGB':
        ret_image = ret_image.convert('RGB')
    return ret_image
+3 −16
Original line number Diff line number Diff line
@@ -5,26 +5,13 @@ Overview:
import numpy as np
from PIL import Image

from ..data import ImageTyping, load_image
from ..data import ImageTyping, load_image, has_alpha_channel

__all__ = [
    'ImageEnhancer',
]


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Check if the image has an alpha channel.

    :param image: The image to check.
    :type image: Image.Image

    :return: True if the image has an alpha channel, False otherwise.
    :rtype: bool
    """
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


class ImageEnhancer:
    """
    Enhances images by applying various processing techniques.
@@ -103,10 +90,10 @@ class ImageEnhancer:
        :rtype: Image.Image
        """
        image = load_image(image, mode=None, force_background=None)
        mode = 'RGBA' if _has_alpha_channel(image) else 'RGB'
        mode = 'RGBA' if has_alpha_channel(image) else 'RGB'
        image = load_image(image, mode=mode, force_background=None)
        input_array = (np.array(image).astype(np.float32) / 255.0).transpose((2, 0, 1))
        if _has_alpha_channel(image):
        if has_alpha_channel(image):
            output_array = self._process_rgba(input_array)
        else:
            output_array = self._process_rgb(input_array)
+3 −7
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ from huggingface_hub import hf_hub_download

from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..data import load_image, ImageTyping, has_alpha_channel
from ..utils import open_onnx_model, vreplace

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
@@ -114,10 +114,6 @@ def _mcut_threshold(probs) -> float:
    return thresh


def _has_alpha_channel(image: Image.Image) -> bool:
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    image = load_image(image, force_background=None, mode=None)
    image_shape = image.size
@@ -126,9 +122,9 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    pad_top = (max_dim - image_shape[1]) // 2

    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    if _has_alpha_channel(image):
    try:
        padded_image.paste(image, (pad_left, pad_top), mask=image)
    else:
    except ValueError:
        padded_image.paste(image, (pad_left, pad_top))

    if max_dim != target_size:
+124 −1
Original line number Diff line number Diff line
import pytest
from PIL import Image

from imgutils.data import load_image
from imgutils.data import load_image, has_alpha_channel, add_background_for_rgba
from test.testings import get_testfile

_FILENAME = get_testfile('6125785.png')
@@ -23,3 +23,126 @@ class TestDataImage:
            assert load_image(image_, force_background=None) is image_
        else:
            assert image_diff(load_image(image_), result, throw_exception=False) < 1e-2

    @pytest.mark.parametrize(['color'], [
        ('white',),
        ('green',),
        ('red',),
        ('blue',),
        ('black',),
    ])
    def test_load_image_bg_rgba(self, image_diff, color):
        image = load_image(get_testfile('nian.png'), force_background=color, mode='RGB')
        expected = Image.open(get_testfile(f'nian_bg_{color}.png'))
        assert image_diff(image, expected, throw_exception=False) < 1e-2

    @pytest.mark.parametrize(['color'], [
        ('white',),
        ('green',),
        ('red',),
        ('blue',),
        ('black',),
    ])
    def test_add_background_for_rgba_rgba(self, image_diff, color):
        image = add_background_for_rgba(get_testfile('nian.png'), background=color)
        assert image.mode == 'RGB'
        expected = Image.open(get_testfile(f'nian_bg_{color}.png'))
        assert image_diff(image, expected, throw_exception=False) < 1e-2

    @pytest.mark.parametrize(['color'], [
        ('white',),
        ('green',),
        ('red',),
        ('blue',),
        ('black',),
    ])
    def test_load_image_bg_rgb(self, image_diff, color):
        image = load_image(get_testfile('mostima_post.jpg'), force_background=color, mode='RGB')
        expected = Image.open(get_testfile(f'mostima_post_bg_{color}.png'))
        assert image_diff(image, expected, throw_exception=False) < 1e-2

    @pytest.mark.parametrize(['color'], [
        ('white',),
        ('green',),
        ('red',),
        ('blue',),
        ('black',),
    ])
    def test_add_backround_for_rgba_rgb(self, image_diff, color):
        image = add_background_for_rgba(get_testfile('mostima_post.jpg'), background=color)
        assert image.mode == 'RGB'
        expected = Image.open(get_testfile(f'mostima_post_bg_{color}.png'))
        assert image_diff(image, expected, throw_exception=False) < 1e-2


@pytest.fixture
def rgba_image():
    img = Image.new('RGBA', (10, 10), (255, 0, 0, 128))
    return img


@pytest.fixture
def rgb_image():
    img = Image.new('RGB', (10, 10), (255, 0, 0))
    return img


@pytest.fixture
def la_image():
    img = Image.new('LA', (10, 10), (128, 128))
    return img


@pytest.fixture
def l_image():
    img = Image.new('L', (10, 10), 128)
    return img


@pytest.fixture
def p_image_with_transparency():
    width, height = 200, 200
    image = Image.new('P', (width, height))

    palette = []
    for i in range(256):
        palette.extend((i, i, i))  # 灰度调色板

    palette[:3] = (0, 0, 0)  # 黑色
    image.info['transparency'] = 0

    image.putpalette(palette)
    return image


@pytest.fixture
def p_image_without_transparency():
    img = Image.new('P', (10, 10))
    palette = [255, 0, 0, 255, 0, 0]  # No transparent color
    img.putpalette(palette)
    return img


@pytest.mark.unittest
class TestHasAlphaChannel:
    def test_rgba_image(self, rgba_image):
        assert has_alpha_channel(rgba_image)

    def test_rgb_image(self, rgb_image):
        assert not has_alpha_channel(rgb_image)

    def test_la_image(self, la_image):
        assert has_alpha_channel(la_image)

    def test_l_image(self, l_image):
        assert not has_alpha_channel(l_image)

    def test_p_image_with_transparency(self, p_image_with_transparency):
        assert has_alpha_channel(p_image_with_transparency)

    def test_p_image_without_transparency(self, p_image_without_transparency):
        assert not has_alpha_channel(p_image_without_transparency)

    def test_pa_image(self):
        pa_image = Image.new('PA', (10, 10))
        assert has_alpha_channel(pa_image)
Loading