Commit 052404f4 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add tag lazy mode, ci skip

parent 5d632f24
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -3,6 +3,11 @@ name: Sync WD14 Models
on:
  #  push:
  workflow_dispatch:
    inputs:
      tag_lazy_mode:
        description: 'Enable Tag Lazy Mode'
        type: boolean
        default: false
  schedule:
    - cron: '30 18 * * 0'

@@ -36,6 +41,16 @@ jobs:
          if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
          if [ -f requirements-test.txt ]; then pip install -r requirements-zoo.txt; fi
          pip install --upgrade build
      # Export TAG_LAZY_MODE to the following steps based on the
      # `tag_lazy_mode` workflow_dispatch input. Runs that carry no inputs
      # (e.g. the cron schedule) fall back to 'false' through the `||` default.
      # The two conditions are complementary, so exactly one step runs.
      - name: Enable Tag Lazy Mode
        if: ${{ (github.event.inputs.tag_lazy_mode || 'false') == 'true' }}
        shell: bash
        run: |
          echo 'TAG_LAZY_MODE=1' >> "$GITHUB_ENV"
      - name: Disable Tag Lazy Mode
        if: ${{ (github.event.inputs.tag_lazy_mode || 'false') != 'true' }}
        shell: bash
        run: |
          # An empty value keeps the variable defined but falsy for consumers
          # that test it for truthiness.
          echo 'TAG_LAZY_MODE=' >> "$GITHUB_ENV"
      - name: Sync Models
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+5 −3
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ _FC_NODE_PREFIXES_FOR_V3 = {
}


def sync():
def sync(tag_lazy_mode: bool = False):
    hf_fs = get_hf_fs()

    import onnxruntime
@@ -134,7 +134,7 @@ def sync():
            else:
                invertible = False

            df = _make_tag_info(model_name)
            df = _make_tag_info(model_name, lazy_mode=tag_lazy_mode)
            assert len(df) == _get_model_tags_length(model_name)
            df.to_csv(os.path.join(td, MODEL_NAMES[model_name], 'tags_info.csv'), index=False)

@@ -176,4 +176,6 @@ def sync():


if __name__ == '__main__':
    sync()
    sync(
        tag_lazy_mode=bool(os.environ.get('TAG_LAZY_MODE')),
    )
+35 −11
Original line number Diff line number Diff line
import json
import logging
import os
from functools import lru_cache
from typing import List, Set
from typing import List, Set, Dict

import numpy as np
import pandas as pd
from ditk import logging
from hfutils.operate import get_hf_fs
from huggingface_hub import hf_hub_download
from tqdm import tqdm
from waifuc.source import DanbooruSource
@@ -79,7 +82,24 @@ def _tags_name_set(model_name) -> Set[str]:
    return set(_tags_list(model_name)['name'])


def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
@lru_cache()
def _load_all_previous(repository: str = 'deepghs/wd14_tagger_with_embeddings') -> Dict[int, dict]:
    """
    Gather the tag records already published in a Hugging Face model repository.

    Every ``tags_info.csv`` matched by ``{repository}/**/tags_info.csv`` is
    downloaded and read with pandas; NaN cells are passed through
    ``.replace(np.nan, None)``. Rows are keyed by their ``tag_id`` column, and
    when the same ``tag_id`` shows up in several files the first one
    encountered is kept.

    The result is memoized per ``repository`` argument via ``lru_cache``, so
    callers share the same dictionary object across calls.

    :param repository: Repository id to scan for ``tags_info.csv`` files.
    :return: Mapping from ``tag_id`` to the full row as a plain dict.
    """
    hf_fs = get_hf_fs()
    known: Dict[int, dict] = {}
    for remote_path in hf_fs.glob(f'{repository}/**/tags_info.csv'):
        # hf_hub_download expects a path relative to the repository root.
        local_csv = hf_hub_download(
            repo_id=repository,
            repo_type='model',
            filename=os.path.relpath(remote_path, f'{repository}'),
        )
        table = pd.read_csv(local_csv).replace(np.nan, None)
        for row in table.to_dict('records'):
            # setdefault keeps the earliest record seen for each tag id.
            known.setdefault(row['tag_id'], row)
    return known


def _make_tag_info(model_name='ConvNext', lazy_mode: bool = False) -> pd.DataFrame:
    with open(hf_hub_download(
            repo_id='deepghs/tags_meta',
            repo_type='dataset',
@@ -88,8 +108,12 @@ def _make_tag_info(model_name='ConvNext') -> pd.DataFrame:
        attire_tags = json.load(f)

    df = _tags_list(model_name)
    d = _load_all_previous()
    records = []
    for item in tqdm(df.to_dict('records')):
        if lazy_mode and item['tag_id'] in d:
            item = d[item['tag_id']]
        else:
            if item['category'] != 9:
                tag_info = _get_tag_by_id(item['tag_id'])
                item['count'] = tag_info['post_count']