Commit f0572c40 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): optimize wd14 code

parent adeef56b
Loading
Loading
Loading
Loading
+3 −11
Original line number Diff line number Diff line
@@ -10,7 +10,6 @@ import cv2
import huggingface_hub
import numpy as np
import pandas as pd
from PIL import Image

from ..data import load_image, ImageTyping
from ..utils import open_onnx_model
@@ -120,19 +119,12 @@ def get_wd14_tags(image: ImageTyping, model_name: str = "ConvNextV2",
        >>> chars
        {'hu_tao_(genshin_impact)': 0.9262397289276123, 'boo_tao_(genshin_impact)': 0.942080020904541}
    """
    image = load_image(image)
    model = _get_wd14_model(model_name)
    _, height, width, _ = model.get_inputs()[0].shape

    # Alpha to white
    image = image.convert("RGBA")
    new_image = Image.new("RGBA", image.size, "WHITE")
    new_image.paste(image, mask=image)
    image = new_image.convert("RGB")
    image = np.asarray(image)

    # PIL RGB to OpenCV BGR
    image = image[:, :, ::-1]
    # Load image, PIL RGB to OpenCV BGR
    image = load_image(image, mode='RGB', force_background='white')
    image = np.asarray(image)[:, :, ::-1]

    image = make_square(image, height)
    image = smart_resize(image, height)