Unverified Commit 28db3872 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #100 from deepghs/fix/nai

dev(narugo): fix nai decode error
parents 51fe9083 6218299e
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -164,7 +164,10 @@ class ImageLsbDataExtractor(object):
            raise ValueError(f'Image magic number mismatch, '
                             f'{self._magic_bytes!r} expected but {read_magic!r}.')

        read_len = reader.read_32bit_integer() // 8
        next_int = reader.read_32bit_integer()
        if next_int is None:
            raise ValueError('No next int32 to read.')
        read_len = next_int // 8
        json_data = reader.get_next_n_bytes(read_len)

        json_data = json.loads(gzip.decompress(json_data).decode("utf-8"))
+6 −1
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ This module is particularly useful for working with AI-generated images and thei
import json
import os
import warnings
import zlib
from dataclasses import dataclass
from typing import Optional, Union

@@ -95,7 +96,11 @@ def _get_naimeta_raw(image: ImageTyping) -> dict:
    image = load_image(image, force_background=None, mode=None)
    try:
        return ImageLsbDataExtractor().extract_data(image)
    except (ValueError, json.JSONDecodeError):
    except (ValueError, json.JSONDecodeError, zlib.error, OSError, UnicodeDecodeError):
        # ValueError: binary data with wrong format
        # json.JSONDecodeError: zot a json-formatted data
        # zlib.error, OSError: not zlib compressed binary data
        # UnicodeDecodeError: cannot decode as utf-8 text
        return image.info or {}


+7 −0
Original line number Diff line number Diff line
@@ -205,3 +205,10 @@ class TestSDNai:
                    save_pnginfo=False, add_lsb_meta=False,
                )
            assert get_naimeta_from_image('image.png') is None

    @pytest.mark.parametrize(['file'], [
        ('118519492_p0.png',),
        ('118438300_p1.png',),
    ])
    def test_image_error_with_wrong_format(self, file):
        assert get_naimeta_from_image(get_testfile(file)) is None
+1.03 MiB
Loading image diff...
+1.33 MiB
Loading image diff...