Loading imgutils/tagging/wd14.py +1 −70 Original line number Diff line number Diff line Loading @@ -3,9 +3,8 @@ Overview: Tagging utils based on wd14 v2, inspired by `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . """ import json from functools import lru_cache from typing import List, Tuple, Optional, Dict, Union from typing import List, Tuple, Dict import numpy as np import onnxruntime Loading Loading @@ -243,71 +242,3 @@ def get_wd14_tags( 'prediction': preds[0].astype(np.float32), } ) @lru_cache() def _inv_data(model_name: str = _DEFAULT_MODEL_NAME): data = np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', repo_type='model', filename=f'{MODEL_NAMES[model_name]}/inv.npz', )) return data['best_epi'], data['inv_weights'], data['bias'] def _inv_sigmoid(x): return np.log(x) - np.log(1 - x) def inv_wd14_by_predictions(predictions: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME, epi: Optional[float] = None, norm: bool = False) -> np.ndarray: best_epi, inv_weights, bias = _inv_data(model_name) eps = 10 ** -(epi if epi is not None else best_epi) pred_input = np.clip(predictions, a_min=eps, a_max=1.0 - eps) inv_emb_output = (_inv_sigmoid(pred_input) - bias) @ inv_weights if norm: inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None] return inv_emb_output @lru_cache() def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]: df_tags = pd.read_csv(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', repo_type='model', filename=f'{MODEL_NAMES[model_name]}/tags_info.csv', )) from .match import _cached_singular_form, _cache_plural_form retval = {} for i, item in enumerate(df_tags.to_dict('records')): tags = sorted({item['name'], *json.loads(item['aliases'])}) for tag in tags: forms = sorted({tag, _cached_singular_form(tag), _cache_plural_form(tag)}) for tag_form in forms: retval[tag_form] = (item['name'], i) return retval, len(df_tags) def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]], model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray: from .format import add_underline if isinstance(tags, (list, tuple)): tags = {tag: 1.0 for tag in tags} mapping, width = _wd14_alias_map(model_name) arr = np.zeros((width,), dtype=np.float32) # arr = np.random.randn(width).astype(np.float32) + 0.5 * 0.25 # arr = np.clip(arr, a_min=0.0, a_max=1.0) for tag, value in tags.items(): origin_tag, tag = tag, add_underline(tag) if tag not in mapping: raise ValueError(f'Unknown tag {origin_tag!r}.') real_tag_name, position = mapping[tag] arr[position] = value return arr Loading
imgutils/tagging/wd14.py +1 −70 Original line number Diff line number Diff line Loading @@ -3,9 +3,8 @@ Overview: Tagging utils based on wd14 v2, inspired by `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ . """ import json from functools import lru_cache from typing import List, Tuple, Optional, Dict, Union from typing import List, Tuple, Dict import numpy as np import onnxruntime Loading Loading @@ -243,71 +242,3 @@ def get_wd14_tags( 'prediction': preds[0].astype(np.float32), } ) @lru_cache() def _inv_data(model_name: str = _DEFAULT_MODEL_NAME): data = np.load(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', repo_type='model', filename=f'{MODEL_NAMES[model_name]}/inv.npz', )) return data['best_epi'], data['inv_weights'], data['bias'] def _inv_sigmoid(x): return np.log(x) - np.log(1 - x) def inv_wd14_by_predictions(predictions: np.ndarray, model_name: str = _DEFAULT_MODEL_NAME, epi: Optional[float] = None, norm: bool = False) -> np.ndarray: best_epi, inv_weights, bias = _inv_data(model_name) eps = 10 ** -(epi if epi is not None else best_epi) pred_input = np.clip(predictions, a_min=eps, a_max=1.0 - eps) inv_emb_output = (_inv_sigmoid(pred_input) - bias) @ inv_weights if norm: inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None] return inv_emb_output @lru_cache() def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]: df_tags = pd.read_csv(hf_hub_download( repo_id='deepghs/wd14_tagger_with_embeddings', repo_type='model', filename=f'{MODEL_NAMES[model_name]}/tags_info.csv', )) from .match import _cached_singular_form, _cache_plural_form retval = {} for i, item in enumerate(df_tags.to_dict('records')): tags = sorted({item['name'], *json.loads(item['aliases'])}) for tag in tags: forms = sorted({tag, _cached_singular_form(tag), _cache_plural_form(tag)}) for tag_form in forms: retval[tag_form] = (item['name'], i) return retval, len(df_tags) def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]], model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray: from .format import add_underline if isinstance(tags, (list, tuple)): tags = {tag: 1.0 for tag in tags} mapping, width = _wd14_alias_map(model_name) arr = np.zeros((width,), dtype=np.float32) # arr = np.random.randn(width).astype(np.float32) + 0.5 * 0.25 # arr = np.clip(arr, a_min=0.0, a_max=1.0) for tag, value in tags.items(): origin_tag, tag = tag, add_underline(tag) if tag not in mapping: raise ValueError(f'Unknown tag {origin_tag!r}.') real_tag_name, position = mapping[tag] arr[position] = value return arr