Unverified Commit e9b46037 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #2 from deepghs/2dvit

add levit 2d model
parents 6d4798d4 9001e5ed
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@ jobs:
        python-version:
          - '3.8'
        model-name:
          - 'lpips'
          #          - 'lpips'
          - 'monochrome'

    steps:
+16 −2
Original line number Diff line number Diff line
from os import PathLike
from typing import Union, BinaryIO, List, Tuple
from typing import Union, BinaryIO, List, Tuple, Optional

from PIL import Image

__all__ = [
    'ImageTyping', 'load_image',
    'MultiImagesTyping', 'load_images',
    'add_background_for_rgba',
]


@@ -17,7 +18,7 @@ ImageTyping = Union[str, PathLike, bytes, bytearray, BinaryIO, Image.Image]
MultiImagesTyping = Union[ImageTyping, List[ImageTyping], Tuple[ImageTyping, ...]]


def load_image(image: ImageTyping, mode=None):
def load_image(image: ImageTyping, mode=None, force_background: Optional[str] = 'white'):
    if isinstance(image, (str, PathLike, bytes, bytearray, BinaryIO)) or _is_readable(image):
        image = Image.open(image)
    elif isinstance(image, Image.Image):
@@ -25,6 +26,9 @@ def load_image(image: ImageTyping, mode=None):
    else:
        raise TypeError(f'Unknown image type - {image!r}.')

    if force_background is not None:
        image = add_background_for_rgba(image, force_background)

    if mode is not None and image.mode != mode:
        image = image.convert(mode)

@@ -36,3 +40,13 @@ def load_images(images: MultiImagesTyping, mode=None) -> List[Image.Image]:
        images = [images]

    return [load_image(item, mode) for item in images]


def add_background_for_rgba(image: ImageTyping, background: str = 'white'):
    if not isinstance(image, Image.Image):
        image = load_image(image)
    image = image.convert("RGBA")
    new_image = Image.new("RGBA", image.size, background)
    new_image.paste(image, mask=image)
    image = new_image.convert("RGB")
    return image
+2 −0
Original line number Diff line number Diff line
from .deepdanbooru import get_deepdanbooru_tags
from .format import tags_to_text
from .wd14 import get_wd14_tags
+82 −0
Original line number Diff line number Diff line
from functools import lru_cache
from typing import Tuple, List

import numpy as np
import pandas as pd
from PIL import Image
from huggingface_hub import hf_hub_download

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model


@lru_cache()
def _get_deepdanbooru_labels():
    csv_file = hf_hub_download('deepghs/imgutils-models', 'deepdanbooru/deepdanbooru_tags.csv')
    df = pd.read_csv(csv_file)

    tag_names = df["name"].tolist()
    tag_real_names = df['real_name'].tolist()
    rating_indexes = list(np.where(df["category"] == 9)[0])
    general_indexes = list(np.where(df["category"] == 0)[0])
    character_indexes = list(np.where(df["category"] == 4)[0])
    return tag_names, tag_real_names, \
           rating_indexes, general_indexes, character_indexes


@lru_cache()
def _get_deepdanbooru_model():
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
        'deepdanbooru/deepdanbooru.onnx',
    ))


def _image_preprocess(image: Image.Image) -> np.ndarray:
    o_width, o_height = image.size
    scale = 512.0 / max(o_width, o_height)
    f_width, f_height = map(lambda x: int(x * scale), (o_width, o_height))
    image = image.resize((f_width, f_height))

    data = np.asarray(image).astype(np.float32) / 255  # H x W x C
    height_pad_left = (512 - f_height) // 2
    height_pad_right = 512 - f_height - height_pad_left
    width_pad_left = (512 - f_width) // 2
    width_pad_right = 512 - f_width - width_pad_left
    data = np.pad(data, ((height_pad_left, height_pad_right), (width_pad_left, width_pad_right), (0, 0)),
                  mode='constant', constant_values=0.0)

    assert data.shape == (512, 512, 3), f'Shape (512, 512, 3) expected, but {data.shape!r} found.'
    return data.reshape((1, 512, 512, 3))  # B x H x W x C


def get_deepdanbooru_tags(image: ImageTyping, use_real_name: bool = False,
                          general_threshold: float = 0.5, character_threshold: float = 0.5):
    session = _get_deepdanbooru_model()
    _image_data = _image_preprocess(load_image(image, mode='RGB'))

    input_name = session.get_inputs()[0].name
    output_names = [output.name for output in session.get_outputs()]
    probs = session.run(output_names, {input_name: _image_data})[0]

    tag_names, tag_real_names, rating_indexes, general_indexes, character_indexes = _get_deepdanbooru_labels()
    labels: List[Tuple[str, float]] = list(zip(
        tag_real_names if use_real_name else tag_names,
        probs[0].astype(float).tolist(),
    ))

    # First 4 labels are actually ratings: pick one with argmax
    ratings_names = [labels[i] for i in rating_indexes]
    rating = dict(ratings_names)

    # Then we have general tags: pick anywhere prediction confidence > threshold
    general_names = [labels[i] for i in general_indexes]
    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)

    # Everything else is characters: pick anywhere prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
    character_res = [x for x in character_names if x[1] > character_threshold]
    character_res = dict(character_res)

    return rating, general_res, character_res
+24 −0
Original line number Diff line number Diff line
import re
from typing import Mapping

RE_SPECIAL = re.compile(r'([\\()])')


def tags_to_text(tags: Mapping[str, float],
                 use_spaces: bool = False, use_escape: bool = True,
                 include_ranks: bool = False, score_descend: bool = True) -> str:
    text_items = []
    tags_pairs = tags.items()
    if score_descend:
        tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))
    for tag, score in tags_pairs:
        t_text = tag
        if use_spaces:
            t_text = t_text.replace('_', ' ')
        if use_escape:
            t_text = re.sub(RE_SPECIAL, r'\\\1', t_text)
        if include_ranks:
            t_text = f"({t_text}:{score:.3f})"
        text_items.append(t_text)

    return ', '.join(text_items)
Loading