Commit 7b282e1a authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add docs for new features

parent 7540259c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -10,4 +10,5 @@ imgutils.sd
    :maxdepth: 3

    metadata
    model
+22 −0
Original line number Diff line number Diff line
imgutils.sd.model
====================================

.. currentmodule:: imgutils.sd.model

.. automodule:: imgutils.sd.model


read_metadata
------------------------------------------

.. autofunction:: read_metadata



save_with_metadata
------------------------------------------

.. autofunction:: save_with_metadata


+33 −0
Original line number Diff line number Diff line
"""
Overview:
    A utility for reading and writing metadata from/to model files in the A41 WebUI format.

    .. note::
        ``torch`` and ``safetensors`` are required by this model.
        Please install them with ``pip install dghs-imgutils[model]`` before start using this part.
"""

from typing import Dict

try:
@@ -12,6 +21,10 @@ except (ImportError, ModuleNotFoundError): # pragma: no cover


def _check_env():
    """
    Checks if the required dependencies (Safetensors and Torch) are installed.
    Raises EnvironmentError if they are not installed.
    """
    if not safetensors:
        raise EnvironmentError(
            'Safetensors not installed. Please use "pip install dghs-imgutils[model]".')  # pragma: no cover
@@ -21,12 +34,32 @@ def _check_env():


def read_metadata(model_file: str) -> Dict[str, str]:
    """
    Reads metadata from a model file and returns it as a dictionary.

    :param model_file: The path to the model file.
    :type model_file: str
    :return: The metadata extracted from the model file.
    :rtype: Dict[str, str]
    """
    _check_env()
    with safetensors.safe_open(model_file, 'pt') as f:
        return f.metadata()


def save_with_metadata(src_model_file: str, dst_model_file: str, metadata: Dict[str, str], clear: bool = False):
    """
    Saves a model file with metadata. Optionally, existing metadata can be cleared before adding new metadata.

    :param src_model_file: The path to the source model file.
    :type src_model_file: str
    :param dst_model_file: The path to save the new model file.
    :type dst_model_file: str
    :param metadata: The metadata to add to the model file.
    :type metadata: Dict[str, str]
    :param clear: Whether to clear existing metadata before adding new metadata. Default is False.
    :type clear: bool
    """
    _check_env()
    with safetensors.safe_open(src_model_file, framework='pt') as f:
        if clear: