Commit c4685afd authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): no tags, ci skip

parent d3dba8fd
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -24,3 +24,4 @@ natsort
tabulate
hfmirror>=0.0.7
tabulate
git+https://github.com/deepghs/waifuc.git@main#egg=waifuc
 No newline at end of file

zoo/wd14/tags.py

0 → 100644
+99 −0
Original line number Diff line number Diff line
import json
import logging
from functools import lru_cache
from typing import List, Set

import pandas as pd
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from waifuc.source import DanbooruSource

from imgutils.tagging.wd14 import MODEL_NAMES, LABEL_FILENAME


@lru_cache()
def _db_session():
    s = DanbooruSource(['1girl'])
    s._prune_session()
    return s.session


@lru_cache(maxsize=65536)
def _get_tag_by_id(tag_id: int):
    session = _db_session()
    return session.get(f'https://danbooru.donmai.us/tags/{tag_id}.json').json()


@lru_cache(maxsize=125536)
def _get_tag_by_name(tag_name: str):
    session = _db_session()
    vs = session.get(
        f'https://danbooru.donmai.us/tags.json',
        params={'search[name]': tag_name}
    ).json()
    return vs[0] if vs else None


@lru_cache(maxsize=65536)
def _simple_search_related_tags(tag: str) -> List[str]:
    session = _db_session()
    tags = []
    for item in session.get(
            'https://danbooru.donmai.us/tag_aliases.json',
            params={
                'search[name_matches]': tag,
            }
    ).json():
        if item['consequent_name'] == tag:
            tags.append(item['antecedent_name'])

    return tags


@lru_cache(maxsize=65536)
def _search_related_tags(tag: str, model_name: str = 'ConvNext') -> List[str]:
    existing_names = _tags_name_set(model_name)
    tags = [tag]
    i = 0
    while i < len(tags):
        append_tags = _simple_search_related_tags(tags[i])
        for tag_ in append_tags:
            if tag_ not in tags and tag_ not in existing_names:
                tags.append(tag_)

        i += 1

    return tags


@lru_cache()
def _tags_list(model_name) -> pd.DataFrame:
    return pd.read_csv(hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME))


@lru_cache()
def _tags_name_set(model_name) -> Set[str]:
    return set(_tags_list(model_name)['name'])


def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
    df = _tags_list(model_name)
    records = []
    for item in tqdm(df.to_dict('records')):
        if item['category'] != 9:
            tag_info = _get_tag_by_id(item['tag_id'])
            item['count'] = tag_info['post_count']
            aliases = _search_related_tags(item['name'], model_name)
            logging.info(f'Aliases {aliases!r} --> {item["name"]!r}')
            item['aliases'] = json.dumps(aliases)
        else:
            item['aliases'] = json.dumps([item['name']])
        records.append(item)

    df_records = pd.DataFrame(records)
    return df_records


if __name__ == "__main__":
    df = _make_tag_info()
    df.to_csv('test_tags_info.csv', index=False)