Commit e4e59aed authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save all, ci skip

parent f95f4c9d
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
name: Sync WD14 Groups

on:
  #  push:
  workflow_dispatch:
  schedule:
    - cron: '30 18 * * 6'

jobs:
  sync:
    name: Sync Waifu2x ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os:
          - 'ubuntu-latest'
        python-version:
          - '3.8'

    steps:
      - name: Checkout code
        uses: actions/checkout@v3
        with:
          fetch-depth: 20
      - name: Set up python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Set up python dependences
        run: |
          pip install --upgrade pip
          pip install --upgrade flake8 setuptools wheel twine
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          if [ -f requirements-build.txt ]; then pip install -r requirements-build.txt; fi
          if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
          if [ -f requirements-test.txt ]; then pip install -r requirements-zoo.txt; fi
          pip install --upgrade build
      - name: Sync Models
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
        run: |
          python -m zoo.wd14.tag_groups
+3 −2
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ from waifuc.utils import srequest

from zoo.wd14.tags import _db_session, _get_tag_by_name

logging.try_init_root(logging.INFO)
session = _db_session()


@@ -165,6 +166,7 @@ def _make_table(limit: Optional[int] = None):

    df_record = pd.DataFrame(records)
    df_record = df_record.replace(np.NaN, False)
    df_record = df_record.sort_values(by=['posts', 'id'], ascending=[False, True])

    groupx = []
    for group_name, (group_category, group_parent) in all_groups.items():
@@ -183,7 +185,7 @@ def sync(repository='deepghs/danbooru_tag_groups'):
    if not hf_client.repo_exists(repo_id=repository, repo_type='dataset'):
        hf_client.create_repo(repo_id=repository, repo_type='dataset', private=True)

    df_record, df_groups = _make_table()
    df_record, df_groups = _make_table(limit=60)
    with TemporaryDirectory() as td:
        df_record.to_csv(os.path.join(td, 'tags.csv'), index=False)
        df_groups.to_csv(os.path.join(td, 'groups.csv'), index=False)
@@ -198,5 +200,4 @@ def sync(repository='deepghs/danbooru_tag_groups'):


if __name__ == '__main__':
    logging.try_init_root(logging.INFO)
    sync()