Commit 6eb7c416 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add has_alpha_channel

parent d39b4513
Loading
Loading
Loading
Loading
+30 −5
Original line number Diff line number Diff line
@@ -4,9 +4,12 @@ from typing import Union, BinaryIO, List, Tuple, Optional
from PIL import Image

__all__ = [
    'ImageTyping', 'load_image',
    'MultiImagesTyping', 'load_images',
    'ImageTyping',
    'load_image',
    'MultiImagesTyping',
    'load_images',
    'add_background_for_rgba',
    'has_alpha_channel',
]


@@ -18,8 +21,30 @@ ImageTyping = Union[str, PathLike, bytes, bytearray, BinaryIO, Image.Image]
MultiImagesTyping = Union[ImageTyping, List[ImageTyping], Tuple[ImageTyping, ...]]


def _has_alpha_channel(image: Image.Image) -> bool:
    return any(band in {'A', 'a', 'P'} for band in image.getbands())
def has_alpha_channel(image: Image.Image) -> bool:
    """
    Determine if the given Pillow image object has an alpha channel (transparency)

    :param image: Pillow image object
    :return: Boolean, True if it has an alpha channel, False otherwise
    """
    # Get the image mode
    mode = image.mode

    # Modes that directly include an alpha channel
    if mode in ('RGBA', 'LA', 'PA'):
        return True

    if getattr(image, 'palette'):
        # Check if there's a transparent palette
        try:
            image.palette.getcolor((0, 0, 0, 0))
            return True  # cannot find a line to trigger this
        except ValueError:
            pass

    # For other modes, check if 'transparency' key exists in image info
    return 'transparency' in image.info


def load_image(image: ImageTyping, mode=None, force_background: Optional[str] = 'white'):
@@ -51,7 +76,7 @@ def load_image(image: ImageTyping, mode=None, force_background: Optional[str] =
    else:
        raise TypeError(f'Unknown image type - {image!r}.')

    if _has_alpha_channel(image) and force_background is not None:
    if has_alpha_channel(image) and force_background is not None:
        image = add_background_for_rgba(image, force_background)

    if mode is not None and image.mode != mode:
+3 −16
Original line number Diff line number Diff line
@@ -5,26 +5,13 @@ Overview:
import numpy as np
from PIL import Image

from ..data import ImageTyping, load_image
from ..data import ImageTyping, load_image, has_alpha_channel

__all__ = [
    'ImageEnhancer',
]


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Check if the image has an alpha channel.

    :param image: The image to check.
    :type image: Image.Image

    :return: True if the image has an alpha channel, False otherwise.
    :rtype: bool
    """
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


class ImageEnhancer:
    """
    Enhances images by applying various processing techniques.
@@ -103,10 +90,10 @@ class ImageEnhancer:
        :rtype: Image.Image
        """
        image = load_image(image, mode=None, force_background=None)
        mode = 'RGBA' if _has_alpha_channel(image) else 'RGB'
        mode = 'RGBA' if has_alpha_channel(image) else 'RGB'
        image = load_image(image, mode=mode, force_background=None)
        input_array = (np.array(image).astype(np.float32) / 255.0).transpose((2, 0, 1))
        if _has_alpha_channel(image):
        if has_alpha_channel(image):
            output_array = self._process_rgba(input_array)
        else:
            output_array = self._process_rgb(input_array)
+2 −6
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ from huggingface_hub import hf_hub_download

from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..data import load_image, ImageTyping, has_alpha_channel
from ..utils import open_onnx_model, vreplace

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
@@ -114,10 +114,6 @@ def _mcut_threshold(probs) -> float:
    return thresh


def _has_alpha_channel(image: Image.Image) -> bool:
    return any(band in {'A', 'a', 'P'} for band in image.getbands())


def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    image = load_image(image, force_background=None, mode=None)
    image_shape = image.size
@@ -126,7 +122,7 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
    pad_top = (max_dim - image_shape[1]) // 2

    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    if _has_alpha_channel(image):
    if has_alpha_channel(image):
        padded_image.paste(image, (pad_left, pad_top), mask=image)
    else:
        padded_image.paste(image, (pad_left, pad_top))
+74 −1
Original line number Diff line number Diff line
import pytest
from PIL import Image

from imgutils.data import load_image
from imgutils.data import load_image, has_alpha_channel
from test.testings import get_testfile

_FILENAME = get_testfile('6125785.png')
@@ -23,3 +23,76 @@ class TestDataImage:
            assert load_image(image_, force_background=None) is image_
        else:
            assert image_diff(load_image(image_), result, throw_exception=False) < 1e-2


@pytest.fixture
def rgba_image():
    img = Image.new('RGBA', (10, 10), (255, 0, 0, 128))
    return img


@pytest.fixture
def rgb_image():
    img = Image.new('RGB', (10, 10), (255, 0, 0))
    return img


@pytest.fixture
def la_image():
    img = Image.new('LA', (10, 10), (128, 128))
    return img


@pytest.fixture
def l_image():
    img = Image.new('L', (10, 10), 128)
    return img


@pytest.fixture
def p_image_with_transparency():
    width, height = 200, 200
    image = Image.new('P', (width, height))

    palette = []
    for i in range(256):
        palette.extend((i, i, i))  # 灰度调色板

    palette[:3] = (0, 0, 0)  # 黑色
    image.info['transparency'] = 0

    image.putpalette(palette)
    return image


@pytest.fixture
def p_image_without_transparency():
    img = Image.new('P', (10, 10))
    palette = [255, 0, 0, 255, 0, 0]  # No transparent color
    img.putpalette(palette)
    return img


@pytest.mark.unittest
class TestHasAlphaChannel:
    def test_rgba_image(self, rgba_image):
        assert has_alpha_channel(rgba_image)

    def test_rgb_image(self, rgb_image):
        assert not has_alpha_channel(rgb_image)

    def test_la_image(self, la_image):
        assert has_alpha_channel(la_image)

    def test_l_image(self, l_image):
        assert not has_alpha_channel(l_image)

    def test_p_image_with_transparency(self, p_image_with_transparency):
        assert has_alpha_channel(p_image_with_transparency)

    def test_p_image_without_transparency(self, p_image_without_transparency):
        assert not has_alpha_channel(p_image_without_transparency)

    def test_pa_image(self):
        pa_image = Image.new('PA', (10, 10))
        assert has_alpha_channel(pa_image)