# imgutils/tagging/wd14.py (+45 −1)
#
# Hunk @@ -3,8 +3,9 @@ (module header; the opening of the module docstring is outside this view):
#     Overview:
#         Tagging utils based on wd14 v2, inspired by
#         `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
#     """
#     import json
#     from functools import lru_cache
#     from typing import List, Tuple, Optional
#     from typing import List, Tuple, Optional, Dict, Union
#
#     import numpy as np
#     import onnxruntime
#
# Hunk @@ -267,3 +268,46 @@ (tail of inv_wd14_by_predictions, whose start is outside this view):
#         if norm:
#             inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None]
#         return inv_emb_output


@lru_cache()
def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]:
    """
    Build the lookup table from every known spelling of a tag to its canonical tag.

    The table is read from ``tags_info.csv`` of the selected model. Each row contributes its
    ``name``, the entries of its JSON-encoded ``aliases`` column, and the singular / plural
    forms of all of those.

    When two rows compete for the same spelling, an exact canonical name wins over an alias,
    and an alias wins over a derived singular / plural form. Within one tier the later row
    wins. This keeps a derived form of one tag from hijacking the exact name of another.

    The result is cached per ``model_name`` and the same dict object is returned on every
    call, so callers must treat it as read-only.

    :param model_name: Key into ``MODEL_NAMES`` selecting which model's tag table to load.
    :return: A tuple ``(mapping, width)``, where ``mapping`` maps a spelling to
        ``(canonical_name, row_index)`` and ``width`` is the number of rows in the table.
    :raises KeyError: If ``model_name`` is not a key of ``MODEL_NAMES``.
    """
    df_tags = pd.read_csv(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        repo_type='model',
        filename=f'{MODEL_NAMES[model_name]}/tags_info.csv',
    ))

    from .match import _cached_singular_form, _cache_plural_form
    by_name, by_alias, by_inflection = {}, {}, {}
    for i, item in enumerate(df_tags.to_dict('records')):
        name = item['name']
        entry = (name, i)
        by_name[name] = entry
        for tag in sorted({name, *json.loads(item['aliases'])}):
            if tag != name:
                by_alias[tag] = entry
            for tag_form in sorted({_cached_singular_form(tag), _cache_plural_form(tag)}):
                if tag_form != tag:
                    by_inflection[tag_form] = entry

    # Later unpacking overrides earlier keys: inflections < aliases < canonical names.
    retval = {**by_inflection, **by_alias, **by_name}
    return retval, len(df_tags)


def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]],
                               model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray:
    """
    Create a prediction-shaped vector with the given tags switched on.

    :param tags: Either a list / tuple of tag names (each gets the value ``1.0``), or a
        mapping from tag name to the value to store for that tag. Names may be canonical
        names, aliases, or singular / plural variants known to the alias table.
    :param model_name: Model whose tag table defines the vector layout.
    :return: A 1-D ``float32`` array with one slot per row of the tag table, zero everywhere
        except at the positions of the requested tags.
    :raises ValueError: If a tag cannot be resolved through the alias table.
    """
    from .format import add_underline
    if isinstance(tags, (list, tuple)):
        tags = {tag: 1.0 for tag in tags}

    mapping, width = _wd14_alias_map(model_name)
    arr = np.zeros((width,), dtype=np.float32)
    for tag, value in tags.items():
        # Normalise the user's spelling before the lookup (presumably spaces -> underscores;
        # see ``add_underline``), but report the spelling the caller actually passed.
        normalized = add_underline(tag)
        if normalized not in mapping:
            raise ValueError(f'Unknown tag {tag!r}.')

        _, position = mapping[normalized]
        arr[position] = value

    return arr

# zoo/wd14/inv.py (+22 −0)
# zoo/wd14/inv.py, hunk @@ -63,6 +63,27 @@
# In the original file the three helpers below are nested inside `_make_inverse`
# (its signature and most of its body are truncated in this view).


def inv_sigmoid(x):
    """
    Return the inverse of the sigmoid (the logit), ``log(x) - log(1 - x)``.

    The result is infinite at ``x == 0`` and ``x == 1``, so callers are expected to clip
    ``x`` away from both ends first.
    """
    return np.log(x) - np.log(1 - x)


def is_inv_safe(v_epi):
    """
    Check whether clipping with ``eps = 10 ** -v_epi`` keeps :func:`inv_sigmoid` finite.

    A float32 probe holding exact ones and zeros is clipped to ``[eps, 1 - eps]`` and pushed
    through :func:`inv_sigmoid`. If ``eps`` is too small for the probe's precision, ``1 - eps``
    collapses to ``1`` and the result overflows to infinity.

    :param v_epi: Negative base-10 exponent of the clipping margin.
    :return: ``True`` if every probe output is finite, ``False`` otherwise.
    """
    eps = 10 ** -v_epi
    probe = np.concatenate([
        np.ones(10).astype(np.float32),
        np.zeros(10).astype(np.float32),
    ])
    # Hitting log(0) is the very condition being detected here, so the floating-point
    # warnings it triggers are expected and silenced for this probe only.
    with np.errstate(divide='ignore', invalid='ignore'):
        x = np.clip(probe, a_min=eps, a_max=1.0 - eps)
        y = inv_sigmoid(x)
    return bool(np.isfinite(y).all())


def get_max_safe_epi(tol=1e-6):
    """
    Bisect for the largest ``v_epi`` in ``[1, 30]`` accepted by :func:`is_inv_safe`.

    :param tol: Width of the search interval at which the bisection stops.
    :return: The lower end of the final interval. It is either the initial bound ``1.0`` or
        a value that was tested and found safe.
    """
    sl, sr = 1.0, 30.0
    while sl < sr - tol:
        sm = (sl + sr) / 2
        if is_inv_safe(sm):
            sl = sm
        else:
            sr = sm

    return sl


# Context of `_make_inverse` that follows the helpers (partially elided in this view):
#     origin = np.load(hf_hub_download(
#         repo_id='deepghs/wd14_tagger_inversion',
#         repo_type='dataset',
#         ...
# Hunk @@ -71,6 +92,7 @@
#     predictions = origin['preds']
#     embeddings = origin['embs']
#
#     right = min(right, get_max_safe_epi())
#     records = []
#     for r in range(rounds):
#         xs, ys = [], []
# imgutils/tagging/wd14.py (+45 −1)
#
# Hunk @@ -3,8 +3,9 @@ (module header; the opening of the module docstring is outside this view):
#     Overview:
#         Tagging utils based on wd14 v2, inspired by
#         `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_ .
#     """
#     import json
#     from functools import lru_cache
#     from typing import List, Tuple, Optional
#     from typing import List, Tuple, Optional, Dict, Union
#
#     import numpy as np
#     import onnxruntime
#
# Hunk @@ -267,3 +268,46 @@ (tail of inv_wd14_by_predictions, whose start is outside this view):
#         if norm:
#             inv_emb_output = inv_emb_output / np.linalg.norm(inv_emb_output, axis=-1)[..., None]
#         return inv_emb_output


@lru_cache()
def _wd14_alias_map(model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[Dict[str, Tuple[str, int]], int]:
    """
    Build the lookup table from every known spelling of a tag to its canonical tag.

    The table is read from ``tags_info.csv`` of the selected model. Each row contributes its
    ``name``, the entries of its JSON-encoded ``aliases`` column, and the singular / plural
    forms of all of those.

    When two rows compete for the same spelling, an exact canonical name wins over an alias,
    and an alias wins over a derived singular / plural form. Within one tier the later row
    wins. This keeps a derived form of one tag from hijacking the exact name of another.

    The result is cached per ``model_name`` and the same dict object is returned on every
    call, so callers must treat it as read-only.

    :param model_name: Key into ``MODEL_NAMES`` selecting which model's tag table to load.
    :return: A tuple ``(mapping, width)``, where ``mapping`` maps a spelling to
        ``(canonical_name, row_index)`` and ``width`` is the number of rows in the table.
    :raises KeyError: If ``model_name`` is not a key of ``MODEL_NAMES``.
    """
    df_tags = pd.read_csv(hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        repo_type='model',
        filename=f'{MODEL_NAMES[model_name]}/tags_info.csv',
    ))

    from .match import _cached_singular_form, _cache_plural_form
    by_name, by_alias, by_inflection = {}, {}, {}
    for i, item in enumerate(df_tags.to_dict('records')):
        name = item['name']
        entry = (name, i)
        by_name[name] = entry
        for tag in sorted({name, *json.loads(item['aliases'])}):
            if tag != name:
                by_alias[tag] = entry
            for tag_form in sorted({_cached_singular_form(tag), _cache_plural_form(tag)}):
                if tag_form != tag:
                    by_inflection[tag_form] = entry

    # Later unpacking overrides earlier keys: inflections < aliases < canonical names.
    retval = {**by_inflection, **by_alias, **by_name}
    return retval, len(df_tags)


def get_wd14_pred_mask_by_tags(tags: Union[List[str], Dict[str, float]],
                               model_name: str = _DEFAULT_MODEL_NAME) -> np.ndarray:
    """
    Create a prediction-shaped vector with the given tags switched on.

    :param tags: Either a list / tuple of tag names (each gets the value ``1.0``), or a
        mapping from tag name to the value to store for that tag. Names may be canonical
        names, aliases, or singular / plural variants known to the alias table.
    :param model_name: Model whose tag table defines the vector layout.
    :return: A 1-D ``float32`` array with one slot per row of the tag table, zero everywhere
        except at the positions of the requested tags.
    :raises ValueError: If a tag cannot be resolved through the alias table.
    """
    from .format import add_underline
    if isinstance(tags, (list, tuple)):
        tags = {tag: 1.0 for tag in tags}

    mapping, width = _wd14_alias_map(model_name)
    arr = np.zeros((width,), dtype=np.float32)
    for tag, value in tags.items():
        # Normalise the user's spelling before the lookup (presumably spaces -> underscores;
        # see ``add_underline``), but report the spelling the caller actually passed.
        normalized = add_underline(tag)
        if normalized not in mapping:
            raise ValueError(f'Unknown tag {tag!r}.')

        _, position = mapping[normalized]
        arr[position] = value

    return arr
# zoo/wd14/inv.py (+22 −0), hunk @@ -63,6 +63,27 @@
# In the original file the three helpers below are nested inside `_make_inverse`
# (its signature and most of its body are truncated in this view).


def inv_sigmoid(x):
    """
    Return the inverse of the sigmoid (the logit), ``log(x) - log(1 - x)``.

    The result is infinite at ``x == 0`` and ``x == 1``, so callers are expected to clip
    ``x`` away from both ends first.
    """
    return np.log(x) - np.log(1 - x)


def is_inv_safe(v_epi):
    """
    Check whether clipping with ``eps = 10 ** -v_epi`` keeps :func:`inv_sigmoid` finite.

    A float32 probe holding exact ones and zeros is clipped to ``[eps, 1 - eps]`` and pushed
    through :func:`inv_sigmoid`. If ``eps`` is too small for the probe's precision, ``1 - eps``
    collapses to ``1`` and the result overflows to infinity.

    :param v_epi: Negative base-10 exponent of the clipping margin.
    :return: ``True`` if every probe output is finite, ``False`` otherwise.
    """
    eps = 10 ** -v_epi
    probe = np.concatenate([
        np.ones(10).astype(np.float32),
        np.zeros(10).astype(np.float32),
    ])
    # Hitting log(0) is the very condition being detected here, so the floating-point
    # warnings it triggers are expected and silenced for this probe only.
    with np.errstate(divide='ignore', invalid='ignore'):
        x = np.clip(probe, a_min=eps, a_max=1.0 - eps)
        y = inv_sigmoid(x)
    return bool(np.isfinite(y).all())


def get_max_safe_epi(tol=1e-6):
    """
    Bisect for the largest ``v_epi`` in ``[1, 30]`` accepted by :func:`is_inv_safe`.

    :param tol: Width of the search interval at which the bisection stops.
    :return: The lower end of the final interval. It is either the initial bound ``1.0`` or
        a value that was tested and found safe.
    """
    sl, sr = 1.0, 30.0
    while sl < sr - tol:
        sm = (sl + sr) / 2
        if is_inv_safe(sm):
            sl = sm
        else:
            sr = sm

    return sl


# Context of `_make_inverse` that follows the helpers (partially elided in this view):
#     origin = np.load(hf_hub_download(
#         repo_id='deepghs/wd14_tagger_inversion',
#         repo_type='dataset',
#         ...
# Hunk @@ -71,6 +92,7 @@
#     predictions = origin['preds']
#     embeddings = origin['embs']
#
#     right = min(right, get_max_safe_epi())
#     records = []
#     for r in range(rounds):
#         xs, ys = [], []