Commit 0ea6021e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): try fix okay, ci skip

parent 0a745476
Loading
Loading
Loading
Loading
+45 −1
Original line number Diff line number Diff line
@@ -3,8 +3,9 @@ Overview:
    Tagging utils based on wd14 v2, inspired by
    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
"""
import json
from functools import lru_cache
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Dict, Union

import numpy as np
import onnxruntime
@@ -267,3 +268,46 @@ def inv_wd14_by_predictions(predictions: np.ndarray, model_name: str = _DEFAULT_
    if norm:
        inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None]
    return inv_emb_output


@lru_cache()
def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]:
    df_tags = pd.read_csv(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        repo_type='model',
        filename=f'{MODEL_NAMES[model_name]}/tags_info.csv',
    ))

    from .match import _cached_singular_form, _cache_plural_form

    retval = {}
    for i, item in enumerate(df_tags.to_dict('records')):
        tags = sorted({item['name'], *json.loads(item['aliases'])})
        for tag in tags:
            forms = sorted({tag, _cached_singular_form(tag), _cache_plural_form(tag)})
            for tag_form in forms:
                retval[tag_form] = (item['name'], i)

    return retval, len(df_tags)


def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]],
                               model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray:
    from .format import add_underline

    if isinstance(tags, (list, tuple)):
        tags = {tag: 1.0 for tag in tags}

    mapping, width = _wd14_alias_map(model_name)
    arr = np.zeros((width,), dtype=np.float32)
    # arr = np.random.randn(width).astype(np.float32) + 0.5 * 0.25
    # arr = np.clip(arr, a_min=0.0, a_max=1.0)
    for tag, value in tags.items():
        origin_tag, tag = tag, add_underline(tag)
        if tag not in mapping:
            raise ValueError(f'Unknown tag {origin_tag!r}.')

        real_tag_name, position = mapping[tag]
        arr[position] = value

    return arr
+22 −0
Original line number Diff line number Diff line
@@ -63,6 +63,27 @@ def _make_inverse(model_name, dst_dir: str, onnx_model_file: Optional[str] = Non
    def inv_sigmoid(x):
        return np.log(x) - np.log(1 - x)

    def is_inv_safe(v_epi):
        eps = 10 ** -v_epi
        p = np.concatenate([
            np.ones(10).astype(np.float32),
            np.zeros(10).astype(np.float32),
        ])
        x = np.clip(p, a_min=eps, a_max=1.0 - eps)
        y = inv_sigmoid(x)
        return not bool(np.isnan(y).any() or np.isinf(y).any())

    def get_max_safe_epi(tol=1e-6):
        sl, sr = 1.0, 30.0
        while sl < sr - tol:
            sm = (sl + sr) / 2
            if is_inv_safe(sm):
                sl = sm
            else:
                sr = sm

        return sl

    origin = np.load(hf_hub_download(
        repo_id='deepghs/wd14_tagger_inversion',
        repo_type='dataset',
@@ -71,6 +92,7 @@ def _make_inverse(model_name, dst_dir: str, onnx_model_file: Optional[str] = Non
    predictions = origin['preds']
    embeddings = origin['embs']

    right = min(right, get_max_safe_epi())
    records = []
    for r in range(rounds):
        xs, ys = [], []