Commit c9ef1051 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add pixai docs, ci skip, fix some bugs

parent db8a1bae
Loading
Loading
Loading
Loading
+21 −4
Original line number Diff line number Diff line
@@ -23,7 +23,8 @@ Overview:
"""

import json
from typing import Union, Dict, Any, Tuple
from collections import defaultdict
from typing import Union, Dict, Any, Tuple, List

import pandas as pd
from hbutils.design import SingletonMark
@@ -81,7 +82,7 @@ def _open_onnx_model(model_name: str):


@ts_lru_cache()
def _open_tags(model_name: str) -> pd.DataFrame:
def _open_tags(model_name: str) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
    """
    Load the tag metadata from Hugging Face Hub with caching.

@@ -95,11 +96,18 @@ def _open_tags(model_name: str) -> pd.DataFrame:
    :return: Tuple of the tag DataFrame (with columns like 'name', 'category') and a dict mapping each tag name to its non-empty 'ips' list
    :rtype: Tuple[pd.DataFrame, Dict[str, List[str]]]
    """
    return pd.read_csv(hf_hub_download(
    df_tags = pd.read_csv(hf_hub_download(
        repo_id=_get_repo_id(model_name),
        repo_type='model',
        filename='selected_tags.csv',
    ))
    d_ips = {}
    if 'ips' in df_tags:
        df_tags['ips'] = df_tags['ips'].map(json.loads)
        for name, ips in zip(df_tags['name'], df_tags['ips']):
            if ips:
                d_ips[name] = ips
    return df_tags, d_ips


@ts_lru_cache()
@@ -250,7 +258,7 @@ def get_pixai_tags(image: ImageTyping, model_name: str = 'v0.9',
        >>> import numpy as np
        >>> normalized_embedding = embedding / np.linalg.norm(embedding)
    """
    df_tags = _open_tags(model_name=model_name)
    df_tags, d_ips = _open_tags(model_name=model_name)
    values = _raw_predict(image, model_name=model_name)
    prediction = values['prediction']
    tags = {}
@@ -288,4 +296,13 @@ def get_pixai_tags(image: ImageTyping, model_name: str = 'v0.9',
        tags.update(cate_tags)

    values['tag'] = tags
    ip_mapping, ip_counts = {}, defaultdict(lambda: 0)
    if 'ips' in df_tags.columns:
        for tag, _ in tags.items():
            if tag in d_ips[tag]:
                ip_mapping[tag] = d_ips[tag]
                for ip_name in d_ips[tag]:
                    ip_counts[ip_name] += 1
    values['ips_mapping'] = ip_mapping
    values['ips'] = [x for x, _ in sorted(ip_counts.items(), key=lambda x: (-x[1], x[0]))]
    return vreplace(fmt, values)
+1 −1
Original line number Diff line number Diff line
@@ -60,7 +60,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
                    exts.append([])
            else:
                exts.append([])
        df_tags['ips'] = exts
        df_tags['ips'] = list(map(json.dumps, exts))
        logging.info(f'Tags:\n{df_tags}')
        df_tags.to_csv(os.path.join(upload_dir, 'selected_tags.csv'), index=False)