Commit 7540259c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add model metadata reader and writer

parent cc802fc1
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -44,6 +44,7 @@ jobs:
          sudo apt-get install -y make wget curl cloc graphviz pandoc
          dot -V
          python -m pip install -r requirements.txt
          python -m pip install -r requirements-model.txt
          python -m pip install -r requirements-doc.txt
      - name: Prepare dataset
        uses: nick-fields/retry@v2
+1 −0
Original line number Diff line number Diff line
@@ -86,6 +86,7 @@ jobs:
          python -m pip install --upgrade pip
          pip install --upgrade flake8 setuptools wheel twine
          pip install -r requirements.txt
          pip install -r requirements-model.txt
          pip install -r requirements-test.txt
      - name: Test the basic environment
        shell: bash
+1 −0
Original line number Diff line number Diff line
@@ -3,3 +3,4 @@ Overview:
    Utilities for dealing with data from `AUTOMATIC1111/stable-diffusion-webui <https://github.com/AUTOMATIC1111/stable-diffusion-webui>`_.
"""
from .metadata import parse_sdmeta_from_text, get_sdmeta_from_image, SDMetaData
from .model import read_metadata, save_with_metadata

imgutils/sd/model.py

0 → 100644
+40 −0
Original line number Diff line number Diff line
from typing import Dict

try:
    import torch
except (ImportError, ModuleNotFoundError):  # pragma: no cover
    torch = None

try:
    import safetensors.torch
except (ImportError, ModuleNotFoundError):  # pragma: no cover
    safetensors = None


def _check_env():
    if not safetensors:
        raise EnvironmentError(
            'Safetensors not installed. Please use "pip install dghs-imgutils[model]".')  # pragma: no cover
    if not torch:
        raise EnvironmentError(
            'Torch not installed. Please use "pip install dghs-imgutils[model]".')  # pragma: no cover


def read_metadata(model_file: str) -> Dict[str, str]:
    _check_env()
    with safetensors.safe_open(model_file, 'pt') as f:
        return f.metadata()


def save_with_metadata(src_model_file: str, dst_model_file: str, metadata: Dict[str, str], clear: bool = False):
    _check_env()
    with safetensors.safe_open(src_model_file, framework='pt') as f:
        if clear:
            new_metadata = {**(metadata or {})}
        else:
            new_metadata = {**f.metadata(), **(metadata or {})}
        safetensors.torch.save_file(
            tensors={key: f.get_tensor(key) for key in f.keys()},
            filename=dst_model_file,
            metadata=new_metadata,
        )

requirements-model.txt

0 → 100644
+2 −0
Original line number Diff line number Diff line
torch
safetensors
 No newline at end of file
Loading