Commit 719e6b01 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): remove pred inversion plan from code, mentioned in #95, ci skip

parent 18df14cd
Loading
Loading
Loading
Loading
+1 −70
Original line number Diff line number Diff line
@@ -3,9 +3,8 @@ Overview:
    Tagging utils based on wd14 v2, inspired by
    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
"""
import json
from functools import lru_cache
from typing import List, Tuple, Optional, Dict, Union
from typing import List, Tuple, Dict

import numpy as np
import onnxruntime
@@ -243,71 +242,3 @@ def get_wd14_tags(
            'prediction': preds[0].astype(np.float32),
        }
    )


@lru_cache()
def _inv_data(model_name: str = _DEFAULT_MODEL_NAME):
    data = np.load(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        repo_type='model',
        filename=f'{MODEL_NAMES[model_name]}/inv.npz',
    ))
    return data['best_epi'], data['inv_weights'], data['bias']


def _inv_sigmoid(x):
    return np.log(x) - np.log(1 - x)


def inv_wd14_by_predictions(predictions: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME,
                            epi: Optional[float] = None, norm: bool = False) -> np.ndarray:
    best_epi, inv_weights, bias = _inv_data(model_name)
    eps = 10 ** -(epi if epi is not None else best_epi)
    pred_input = np.clip(predictions, a_min=eps, a_max=1.0 - eps)
    inv_emb_output = (_inv_sigmoid(pred_input) - bias) @ inv_weights
    if norm:
        inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None]
    return inv_emb_output


@lru_cache()
def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]:
    df_tags = pd.read_csv(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        repo_type='model',
        filename=f'{MODEL_NAMES[model_name]}/tags_info.csv',
    ))

    from .match import _cached_singular_form, _cache_plural_form

    retval = {}
    for i, item in enumerate(df_tags.to_dict('records')):
        tags = sorted({item['name'], *json.loads(item['aliases'])})
        for tag in tags:
            forms = sorted({tag, _cached_singular_form(tag), _cache_plural_form(tag)})
            for tag_form in forms:
                retval[tag_form] = (item['name'], i)

    return retval, len(df_tags)


def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]],
                               model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray:
    from .format import add_underline

    if isinstance(tags, (list, tuple)):
        tags = {tag: 1.0 for tag in tags}

    mapping, width = _wd14_alias_map(model_name)
    arr = np.zeros((width,), dtype=np.float32)
    # arr = np.random.randn(width).astype(np.float32) + 0.5 * 0.25
    # arr = np.clip(arr, a_min=0.0, a_max=1.0)
    for tag, value in tags.items():
        origin_tag, tag = tag, add_underline(tag)
        if tag not in mapping:
            raise ValueError(f'Unknown tag {origin_tag!r}.')

        real_tag_name, position = mapping[tag]
        arr[position] = value

    return arr