Commit 1fa09eea authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): sync waifu2x model

parent 5644cfb7
Loading
Loading
Loading
Loading
+42 −0
Original line number Diff line number Diff line
name: Sync Waifu2x Models

on:
  push:
  workflow_dispatch:

jobs:
  sync:
    name: Sync Waifu2x ONNX
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os:
          - 'ubuntu-latest'
        python-version:
          - '3.8'

    steps:
      - name: Checkout code
        uses: actions/checkout@v3
        with:
          fetch-depth: 20
      - name: Set up python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Set up python dependences
        run: |
          pip install --upgrade pip
          pip install --upgrade flake8 setuptools wheel twine
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          if [ -f requirements-build.txt ]; then pip install -r requirements-build.txt; fi
          if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
          if [ -f requirements-test.txt ]; then pip install -r requirements-zoo.txt; fi
          pip install --upgrade build
      - name: Sync Models
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
        run: |
          python -m zoo.waifu2x sync
+2 −1
Original line number Diff line number Diff line
@@ -22,3 +22,4 @@ controlnet_aux
lighttuner
natsort
tabulate
hfmirror>=0.0.7
 No newline at end of file
+0 −0

Empty file added.

+31 −0
Original line number Diff line number Diff line
from functools import partial

import click
from ditk import logging

from .sync import sync_to_huggingface
from ..utils import GLOBAL_CONTEXT_SETTINGS
from ..utils import print_version as _origin_print_version

print_version = partial(_origin_print_version, 'zoo.waifu2x')


@click.group(context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('-v', '--version', is_flag=True,
              callback=print_version, expose_value=False, is_eager=True,
              help="Utils with waifu2x models.")
def cli():
    pass  # pragma: no cover


@cli.command('sync', help='Export feature extract model as onnx.',
             context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('--repository', '-r', 'repository', type=str, default='deepghs/waifu2x_onnx',
              help='Repository to sync.', show_default=True)
def sync(repository: str):
    logging.try_init_root(logging.INFO)
    sync_to_huggingface(repository)


if __name__ == '__main__':
    cli()

zoo/waifu2x/sync.py

0 → 100644
+52 −0
Original line number Diff line number Diff line
import os
import re
import zipfile
from contextlib import contextmanager

from github import Github
from hbutils.system import TemporaryDirectory
from hfmirror.resource import LocalDirectoryResource
from hfmirror.storage import HuggingfaceStorage
from hfmirror.sync import SyncTask
from hfmirror.utils import download_file
from huggingface_hub import HfApi
from tqdm.auto import tqdm

MODEL_ASSET_PATTERN = re.compile(r'^waifu2x_onnx_models_(?P<version>[\s\S]*)\.zip$')


@contextmanager
def load_model_project():
    github_client = Github(os.environ['GH_ACCESS_TOKEN'])
    repo = github_client.get_repo('nagadomi/nunif')
    release = repo.get_release('0.0.0')
    with TemporaryDirectory() as ztd, TemporaryDirectory() as ptd:
        for asset in tqdm(release.get_assets()):
            matching = MODEL_ASSET_PATTERN.fullmatch(asset.name)
            if not matching:
                continue

            version = matching.group('version')
            url = asset.browser_download_url
            zip_file = os.path.join(ztd, asset.name)
            download_file(url, zip_file)

            version_dir = os.path.join(ptd, version)
            os.makedirs(version_dir, exist_ok=True)
            with zipfile.ZipFile(zip_file, 'r') as zf:
                zf.extractall(version_dir)

            break

        yield ptd


def sync_to_huggingface(repository: str = 'deepghs/waifu2x_onnx'):
    hf_client = HfApi(token=os.environ['HF_TOKEN'])
    hf_client.create_repo(repository, repo_type='model', exist_ok=True)
    storage = HuggingfaceStorage(repository, repo_type='model', hf_client=hf_client)

    with load_model_project() as ptd:
        resource = LocalDirectoryResource(ptd)
        task = SyncTask(resource, storage)
        task.sync()