Commit d3dba8fd authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): embedding inverse

parent 8d90e3cc
Loading
Loading
Loading
Loading
+26 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ Overview:
    `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
"""
from functools import lru_cache
from typing import List, Tuple
from typing import List, Tuple, Optional

import numpy as np
import onnxruntime
@@ -242,3 +242,28 @@ def get_wd14_tags(
            'prediction': preds[0].astype(np.float32),
        }
    )


@lru_cache()
def _inv_data(model_name: str = _DEFAULT_MODEL_NAME):
    data = np.load(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        repo_type='model',
        filename=f'{MODEL_NAMES[model_name]}/inv.npz',
    ))
    return data['best_epi'], data['inv_weights'], data['bias']


def _inv_sigmoid(x):
    return np.log(x) - np.log(1 - x)


def inv_wd14_by_predictions(predictions: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME,
                            epi: Optional[float] = None, norm: bool = False) -> np.ndarray:
    best_epi, inv_weights, bias = _inv_data(model_name)
    eps = 10 ** -(epi if epi is not None else best_epi)
    pred_input = np.clip(predictions, a_min=eps, a_max=1.0 - eps)
    inv_emb_output = (_inv_sigmoid(pred_input) - bias) @ inv_weights
    if norm:
        inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None]
    return inv_emb_output