Commit 159dce69 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save these updates

parent f0354f40
Loading
Loading
Loading
Loading
+44 −6
Original line number Diff line number Diff line
@@ -269,6 +269,45 @@ def parse_sdmeta_from_text(x: str) -> SDMetaData:
    return SDMetaData(prompt, neg_prompt, params)


class _InvalidSDMetaError(Exception):
    pass


def _sdtext_validate(text: str):
    from .nai import _naimeta_text_validate, _InvalidNAIMetaError

    try:
        _naimeta_text_validate(text)
    except _InvalidNAIMetaError:
        pass
    else:
        raise _InvalidSDMetaError

    if text:
        return text
    else:
        raise _InvalidSDMetaError


def _get_raw_sdtext(image: ImageTyping) -> Optional[str]:
    image = load_image(image, force_background=None, mode=None)

    try:
        return _sdtext_validate(read_geninfo_parameters(image))
    except _InvalidSDMetaError:
        pass

    try:
        return _sdtext_validate(read_geninfo_exif(image))
    except _InvalidSDMetaError:
        pass

    try:
        return _sdtext_validate(read_geninfo_gif(image))
    except _InvalidSDMetaError:
        raise _InvalidSDMetaError


def get_sdmeta_from_image(image: ImageTyping) -> Optional[SDMetaData]:
    """
    Extract and parse Stable Diffusion metadata from an image.
@@ -297,13 +336,12 @@ def get_sdmeta_from_image(image: ImageTyping) -> Optional[SDMetaData]:
        ...     print("No SD metadata found in the image.")
    """
    image = load_image(image, mode=None, force_background=None)
    pnginfo_text = (read_geninfo_parameters(image) or
                    read_geninfo_exif(image) or
                    read_geninfo_gif(image))
    if pnginfo_text:
        return parse_sdmeta_from_text(pnginfo_text)
    else:
    try:
        pnginfo_text = _get_raw_sdtext(image)
    except _InvalidSDMetaError:
        return None
    else:
        return parse_sdmeta_from_text(pnginfo_text)


def _save_png_with_sdmeta(image: Image.Image, dst_file: Union[str, os.PathLike], metadata: SDMetaData, **kwargs):
+25 −6
Original line number Diff line number Diff line
@@ -116,11 +116,30 @@ def _naimeta_validate(data):
    :raises _InvalidNAIMetaError: If the metadata is invalid.
    """
    if isinstance(data, dict) and data.get('Software') and data.get('Source') and data.get('Comment'):
        try:
            json.loads(data['Comment'])
        except (TypeError, json.JSONDecodeError):
            raise _InvalidNAIMetaError

        if data.get('Generation time'):
            try:
                _ = float(data['Generation time'])
            except (TypeError, ValueError):
                raise _InvalidNAIMetaError

        return data

    else:
        raise _InvalidNAIMetaError


def _naimeta_text_validate(data):
    try:
        return _naimeta_validate(json.loads(data))
    except (TypeError, json.JSONDecodeError):
        raise _InvalidNAIMetaError


def _get_naimeta_raw(image: ImageTyping) -> dict:
    """
    Extract raw NAI metadata from an image.
@@ -146,18 +165,18 @@ def _get_naimeta_raw(image: ImageTyping) -> dict:
        pass

    try:
        return _naimeta_validate(json.loads(read_geninfo_parameters(image)))
    except (TypeError, json.JSONDecodeError, _InvalidNAIMetaError):
        return _naimeta_text_validate(read_geninfo_parameters(image))
    except _InvalidNAIMetaError:
        pass

    try:
        return _naimeta_validate(json.loads(read_geninfo_exif(image)))
    except (TypeError, json.JSONDecodeError, _InvalidNAIMetaError):
        return _naimeta_text_validate(read_geninfo_exif(image))
    except _InvalidNAIMetaError:
        pass

    try:
        return _naimeta_validate(json.loads(read_geninfo_gif(image)))
    except (TypeError, json.JSONDecodeError, _InvalidNAIMetaError):
        return _naimeta_text_validate(read_geninfo_gif(image))
    except _InvalidNAIMetaError:
        raise _InvalidNAIMetaError


+13 −0
Original line number Diff line number Diff line
@@ -11,6 +11,11 @@ from imgutils.sd import get_sdmeta_from_image, SDMetaData, parse_sdmeta_from_tex
from test.testings import get_testfile


@pytest.fixture()
def nai3_file():
    return get_testfile('nai3.png')


@pytest.fixture()
def clean_image():
    return get_testfile('nai3_clear.png')
@@ -424,3 +429,11 @@ Steps: 20, Sampler: DPM++ 2M SDE Karras, CFG scale: 7, Seed: 2647703743, Size: 7
            else:
                save_image_with_sdmeta(clean_image, f'image{ext}', metadata=sdimg_4_std)
                assert get_sdmeta_from_image(f'image{ext}') == sdimg_4_std

    @pytest.mark.parametrize(['file'], [
        ('nai3.png',),
        ('nai3_clear.png',),
        ('nai3_info_rgb.png',),
    ])
    def test_clean_image(self, file):
        assert get_sdmeta_from_image(get_testfile(file)) is None