Commit 577cbce3 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add nai metadata support

parent 34af2007
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -4,3 +4,4 @@ Overview:
"""
from .metadata import parse_sdmeta_from_text, get_sdmeta_from_image, SDMetaData
from .model import read_metadata, save_with_metadata
from .nai import get_naimeta_from_image, NAIMetadata
+3 −0
Original line number Diff line number Diff line
from .extract import LSBExtractor, ImageLsbDataExtractor
from .inject import serialize_metadata, inject_data
from .metadata import get_naimeta_from_image, NAIMetadata, add_naimeta_to_image, save_image_with_naimeta
+75 −0
Original line number Diff line number Diff line
import gzip
import json

import numpy as np
from PIL import Image


# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta.py
class LSBExtractor(object):
    def __init__(self, data: np.ndarray):
        self.data = data
        self.rows, self.cols, self.dim = data.shape
        self.bits = 0
        self.byte = 0
        self.row = 0
        self.col = 0

    def _extract_next_bit(self):
        if self.row < self.rows and self.col < self.cols:
            bit = self.data[self.row, self.col, self.dim - 1] & 1
            self.bits += 1
            self.byte <<= 1
            self.byte |= bit
            self.row += 1
            if self.row == self.rows:
                self.row = 0
                self.col += 1

    def get_one_byte(self):
        while self.bits < 8:
            self._extract_next_bit()
        byte = bytearray([self.byte])
        self.bits = 0
        self.byte = 0
        return byte

    def get_next_n_bytes(self, n):
        bytes_list = bytearray()
        for _ in range(n):
            byte = self.get_one_byte()
            if not byte:
                break
            bytes_list.extend(byte)
        return bytes_list

    def read_32bit_integer(self):
        bytes_list = self.get_next_n_bytes(4)
        if len(bytes_list) == 4:
            integer_value = int.from_bytes(bytes_list, byteorder='big')
            return integer_value
        else:
            return None


# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta.py
class ImageLsbDataExtractor(object):
    def __init__(self, magic: str = "stealth_pngcomp"):
        self._magic_bytes = magic.encode('utf-8')

    def extract_data(self, image: Image.Image) -> dict:
        if image.mode != 'RGBA':
            raise ValueError(f'Image should be in RGBA mode, but {image.mode!r} found.')
        image = np.array(image)
        reader = LSBExtractor(image)

        read_magic = reader.get_next_n_bytes(len(self._magic_bytes))
        if not (self._magic_bytes == read_magic):
            raise ValueError(f'Image magic number mismatch, '
                             f'{self._magic_bytes!r} expected but {read_magic!r}.')

        read_len = reader.read_32bit_integer() // 8
        json_data = reader.get_next_n_bytes(read_len)

        json_data = json.loads(gzip.decompress(json_data).decode("utf-8"))
        return json_data
+138 −0
Original line number Diff line number Diff line
# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta_writer.py
import gzip
import json

# BCH error correction
import bchlib
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo

correctable_bits = 16
block_length = 2019
code_block_len = 1920


def bit_shuffle(data_bytes, w, h, use_bytes=False):
    bits = np.frombuffer(data_bytes, dtype=np.uint8)
    bit_fac = 8
    if use_bytes:
        bit_fac = 1
    else:
        bits = np.unpackbits(bits)
    bits = bits.reshape((h, w, 3 * bit_fac))
    code_block_len = 1920
    flat_tile_len = (w * h * 3) // code_block_len
    tile_w = 32
    if flat_tile_len // tile_w > 100:
        tile_w = 64
    tile_h = flat_tile_len // tile_w
    h_cutoff = (h // tile_h) * tile_h
    tile_hr = h - h_cutoff
    easy_tiles = bits[:h_cutoff].reshape(h_cutoff // tile_h, tile_h, w // tile_w, tile_w, 3 * bit_fac)
    easy_tiles = easy_tiles.swapaxes(1, 2)
    easy_tiles = easy_tiles.reshape(-1, tile_h * tile_w)
    easy_tiles = easy_tiles.T
    rest_tiles = bits[h_cutoff:]
    rest_tiles = rest_tiles.reshape(tile_hr, 1, w // tile_w, tile_w, 3 * bit_fac)
    rest_tiles = rest_tiles.swapaxes(1, 2)
    rest_tiles = rest_tiles.reshape(-1, tile_hr * tile_w)
    rest_tiles = rest_tiles.T
    rest_dim = rest_tiles.shape[-1]
    rest_tiles = np.pad(rest_tiles, ((0, 0), (0, easy_tiles.shape[-1] - rest_tiles.shape[-1])), mode='constant',
                        constant_values=0)
    bits = np.concatenate((easy_tiles, rest_tiles), axis=0)
    dim = bits.shape[-1]
    bits = bits.reshape((-1,))
    if not use_bytes:
        bits = np.packbits(bits)
    return bytearray(bits.tobytes()), dim, rest_tiles.shape[0], rest_dim


def split_byte_ranges(data_bytes, n, w, h):
    data_bytes, dim, rest_size, rest_dim = bit_shuffle(data_bytes.copy(), w, h, use_bytes=True)
    chunks = []
    for i in range(0, len(data_bytes), n):
        chunks.append(data_bytes[i:i + n])
    return chunks, dim, rest_size, rest_dim


def pad(data_bytes):
    return bytearray(data_bytes + b'\x00' * (2019 - len(data_bytes)))


# Returns codes for the data in data_bytes
def fec_encode(data_bytes, w, h):
    encoder = bchlib.BCH(16, prim_poly=17475)
    # import galois
    # encoder = galois.BCH(16383, 16383-224, d=17, c=224)
    chunks = [bytearray(encoder.encode(pad(x))) for x in split_byte_ranges(data_bytes, 2019, w, h)[0]]
    return b''.join(chunks)


class LSBInjector:
    def __init__(self, data):
        self.data = data
        self.buffer = bytearray()

    def put_byte(self, byte):
        self.buffer.append(byte)

    def put_32bit_integer(self, integer_value):
        self.buffer.extend(integer_value.to_bytes(4, byteorder='big'))

    def put_bytes(self, bytes_list):
        self.buffer.extend(bytes_list)

    def put_string(self, string):
        self.put_bytes(string.encode('utf-8'))

    def finalize(self):
        buffer = np.frombuffer(self.buffer, dtype=np.uint8)
        buffer = np.unpackbits(buffer)
        data = self.data[..., -1].T
        h, w = data.shape
        data = data.reshape((-1,))
        data[:] = 0xff
        buf_len = buffer.shape[0]
        data[:buf_len] = 0xfe
        data[:buf_len] = np.bitwise_or(data[:buf_len], buffer)
        data = data.reshape((h, w)).T
        self.data[..., -1] = data


def serialize_metadata(metadata: PngInfo) -> bytes:
    # Extract metadata from PNG chunks
    data = {
        k: v
        for k, v in [
            data[1]
            .decode("latin-1" if data[0] == b"tEXt" else "utf-8")
            .split("\x00" if data[0] == b"tEXt" else "\x00\x00\x00\x00\x00")
            for data in metadata.chunks
            if data[0] == b"tEXt" or data[0] == b"iTXt"
        ]
    }
    # Save space by getting rid of reduntant metadata (Title is static)
    if "Title" in data:
        del data["Title"]
    # Encode and compress data using gzip
    data_encoded = json.dumps(data)
    return gzip.compress(bytes(data_encoded, "utf-8"))


def inject_data(image: Image.Image, data: PngInfo) -> Image.Image:
    rgb = np.array(image.convert('RGB'))
    image = image.convert('RGBA')
    w, h = image.size
    pixels = np.array(image)
    injector = LSBInjector(pixels)
    injector.put_string("stealth_pngcomp")
    data = serialize_metadata(data)
    injector.put_32bit_integer(len(data) * 8)
    injector.put_bytes(data)
    fec_data = fec_encode(bytearray(rgb.tobytes()), w, h)
    injector.put_32bit_integer(len(fec_data) * 8)
    injector.put_bytes(fec_data)
    injector.finalize()
    return Image.fromarray(injector.data)
+88 −0
Original line number Diff line number Diff line
import json
import os
from dataclasses import dataclass
from typing import Optional, Union

from PIL import Image
from PIL.PngImagePlugin import PngInfo

from .extract import ImageLsbDataExtractor
from .inject import inject_data
from ...data import load_image, ImageTyping


@dataclass
class NAIMetadata:
    software: str
    source: str
    title: Optional[str] = None
    generation_time: Optional[float] = None
    description: Optional[str] = None
    parameters: Optional[dict] = None

    @property
    def pnginfo(self) -> PngInfo:
        info = PngInfo()
        info.add_text('Software', self.software)
        info.add_text('Source', self.source)
        if self.title is not None:
            info.add_text('Title', self.title)
        if self.generation_time is not None:
            info.add_text('Generation time', json.dumps(self.generation_time)),
        if self.description is not None:
            info.add_text('Description', self.description)
        if self.parameters is not None:
            info.add_text('Comment', json.dumps(self.parameters))
        return info


def _get_naimeta_raw(image: ImageTyping) -> dict:
    image = load_image(image, force_background=None, mode=None)
    try:
        return ImageLsbDataExtractor().extract_data(image)
    except (ValueError, json.JSONDecodeError):
        return image.info or {}


def get_naimeta_from_image(image: ImageTyping) -> Optional[NAIMetadata]:
    data = _get_naimeta_raw(image)
    if data.get('Software') and data.get('Source'):
        return NAIMetadata(
            software=data['Software'],
            source=data['Source'],
            title=data.get('Title'),
            generation_time=float(data['Generation time']) if data.get('Generation time') else None,
            description=data.get('Description'),
            parameters=json.loads(data['Comment']) if data.get('Comment') else None,
        )
    else:
        return None


def _get_pnginfo(metadata: Union[NAIMetadata, PngInfo]) -> PngInfo:
    if isinstance(metadata, NAIMetadata):
        pnginfo = metadata.pnginfo
    elif isinstance(metadata, PngInfo):
        pnginfo = metadata
    else:
        raise TypeError(f'Unknown metadata type for NAI - {metadata!r}.')
    return pnginfo


def add_naimeta_to_image(image: ImageTyping, metadata: Union[NAIMetadata, PngInfo]) -> Image.Image:
    pnginfo = _get_pnginfo(metadata)
    image = load_image(image, mode=None, force_background=None)
    return inject_data(image, data=pnginfo)


def save_image_with_naimeta(image: ImageTyping, dst_file: Union[str, os.PathLike],
                            metadata: Union[NAIMetadata, PngInfo],
                            add_lsb_meta: bool = True, save_pnginfo: bool = True, **kwargs) -> Image.Image:
    pnginfo = _get_pnginfo(metadata)
    image = load_image(image, mode=None, force_background=None)
    if add_lsb_meta:
        image = add_naimeta_to_image(image, metadata=pnginfo)
    if save_pnginfo:
        kwargs['pnginfo'] = pnginfo
    image.save(dst_file, **kwargs)
    return image
Loading