Commit 859e829c authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add wd14 syncer, ci skip

parent 090c2810
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
# Workflow that runs the WD14 model syncer (`python -m zoo.wd14.sync`),
# either manually or on a daily schedule.
name: Sync WD14 Models

on:
  #  push:
  workflow_dispatch:
  schedule:
    # Daily at 18:30 (GitHub Actions evaluates cron schedules in UTC).
    - cron: '30 18 * * *'

jobs:
  sync:
    name: Sync WD14 Models
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os:
          - 'ubuntu-latest'
        python-version:
          - '3.8'

    steps:
      - name: Checkout code
        uses: actions/checkout@v3
        with:
          fetch-depth: 20
      - name: Set up python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Set up python dependencies
        run: |
          pip install --upgrade pip
          pip install --upgrade flake8 setuptools wheel twine
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          if [ -f requirements-build.txt ]; then pip install -r requirements-build.txt; fi
          if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
          # Guard on the file that is actually installed.
          if [ -f requirements-zoo.txt ]; then pip install -r requirements-zoo.txt; fi
          pip install --upgrade build
      - name: Sync Models
        env:
          # Tokens are exposed to the sync script through the environment.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
        run: |
          python -m zoo.wd14.sync

zoo/wd14/__init__.py

0 → 100644
+0 −0

Empty file added.

zoo/wd14/sync.py

0 → 100644
+152 −0
Original line number Diff line number Diff line
import os.path
import re
from functools import lru_cache

import numpy as np
import onnx
import onnxruntime
import pandas as pd
from ditk import logging
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hfutils.operate import upload_directory_as_directory
from huggingface_hub import hf_hub_download
from onnx.helper import make_tensor_value_info
from tqdm import tqdm

from imgutils.tagging.wd14 import MODEL_NAMES

logging.try_init_root(logging.INFO)


@lru_cache()
def _get_model_file(name) -> str:
    """
    Fetch the ``model.onnx`` file of a WD14 tagger from the Hugging Face Hub.

    The result is memoized per ``name`` via ``lru_cache``.

    :param name: Key of the model in ``MODEL_NAMES``; the mapped value is
        used as the Hub repository id.
    :return: Value returned by ``hf_hub_download`` for ``model.onnx``.
    """
    source_repo = MODEL_NAMES[name]
    return hf_hub_download(repo_id=source_repo, filename='model.onnx')


@lru_cache()
def _get_model_tags_length(name) -> int:
    """
    Count the rows of a WD14 tagger's ``selected_tags.csv``.

    The result is memoized per ``name`` via ``lru_cache``.

    :param name: Key of the model in ``MODEL_NAMES``; the mapped value is
        used as the Hub repository id.
    :return: Number of data rows in the downloaded CSV, as read by pandas.
    """
    tags_csv = hf_hub_download(
        repo_id=MODEL_NAMES[name],
        filename='selected_tags.csv',
    )
    df_tags = pd.read_csv(tags_csv)
    return len(df_tags)


def _seg_split(text):
    return tuple(filter(bool, re.split(r'[./]+', text)))


_FC_KEYWORDS_FOR_V2 = {'predictions_dense'}
_FC_NODE_PREFIXES_FOR_V3 = {
    "SwinV2_v3": ('core_model', 'head', 'fc'),
    "ConvNext_v3": ('core_model', 'head', 'fc'),
    "ViT_v3": ('core_model', 'head'),
}

# Script entry point: for every model in MODEL_NAMES, expose the input of its
# final fc layer as a second ("embeddings") graph output, verify the result
# with onnxruntime, then upload all rebuilt models plus a README to the Hub.
if __name__ == '__main__':
    with TemporaryDirectory() as td:
        records = []
        for model_name in tqdm(MODEL_NAMES):
            model_file = _get_model_file(model_name)
            logging.info(f'Model name: {model_name!r}, model file: {model_file!r}')
            logging.info(f'Loading model {model_name!r} ...')
            model = onnx.load(model_file)
            embs_outputs = []

            # Pick the rule that decides whether a node/tensor name belongs
            # to the fc layer: prefix match for the listed v3 models,
            # keyword match for everything else.
            if model_name in _FC_NODE_PREFIXES_FOR_V3:
                prefix = _FC_NODE_PREFIXES_FOR_V3[model_name]

                def _is_fc(name):
                    return _seg_split(name)[:len(prefix)] == prefix
            else:
                def _is_fc(name):
                    return any(seg in _FC_KEYWORDS_FOR_V2 for seg in _seg_split(name))

            # Inputs of fc nodes that are not themselves fc-named are the
            # tensors entering the fc layer from outside, i.e. the embeddings.
            for node in model.graph.node:
                if _is_fc(node.name):
                    for input_name in node.input:
                        if not _is_fc(input_name):
                            logging.info(f'Input {input_name!r} for fc layer {node.name!r}.')
                            embs_outputs.append(input_name)

            logging.info(f'Embedding outputs: {embs_outputs!r}.')
            assert len(embs_outputs) == 1, f'Outputs: {embs_outputs!r}'
            # First pass: register the output by name only, so that
            # onnxruntime can report the actual shape of the embeddings.
            model.graph.output.extend([onnx.ValueInfoProto(name=embs_outputs[0])])

            logging.info('Analysing via onnxruntime ...')
            session = onnxruntime.InferenceSession(model.SerializeToString())
            input_data = np.random.randn(1, 448, 448, 3).astype(np.float32)
            assert len(session.get_inputs()) == 1
            assert len(session.get_outputs()) == 2
            assert session.get_outputs()[1].name == embs_outputs[0]

            tags_data, embeddings = session.run([], {session.get_inputs()[0].name: input_data})
            logging.info(f'Tag output, shape: {tags_data.shape!r}, dtype: {tags_data.dtype!r}')
            logging.info(f'Embeddings output, shape: {embeddings.shape!r}, dtype: {embeddings.dtype!r}')
            assert tags_data.shape == (1, _get_model_tags_length(model_name))
            assert len(embeddings.shape) == 2 and embeddings.shape[0] == 1
            emb_width = embeddings.shape[-1]

            # Second pass: reload the untouched model and add the output
            # again, this time with full type and shape information.
            logging.info('Remaking model ...')
            model = onnx.load(model_file)
            model.graph.output.extend([make_tensor_value_info(
                name=embs_outputs[0],
                elem_type=onnx.TensorProto.FLOAT,
                shape=embeddings.shape,
            )])

            onnx_file = os.path.join(td, MODEL_NAMES[model_name], 'model.onnx')
            os.makedirs(os.path.dirname(onnx_file), exist_ok=True)
            onnx.save_model(model, onnx_file)

            # Re-run the same checks against the file that will be uploaded.
            logging.info(f'Loading and testing for the exported model {onnx_file!r}.')
            session = onnxruntime.InferenceSession(onnx_file)
            assert len(session.get_inputs()) == 1
            assert len(session.get_outputs()) == 2
            assert session.get_outputs()[1].name == embs_outputs[0]
            assert session.get_outputs()[1].shape == [1, emb_width]

            tags_data, embeddings = session.run([], {session.get_inputs()[0].name: input_data})
            logging.info(f'Tag output, shape: {tags_data.shape!r}, dtype: {tags_data.dtype!r}')
            logging.info(f'Embeddings output, shape: {embeddings.shape!r}, dtype: {embeddings.dtype!r}')
            assert tags_data.shape == (1, _get_model_tags_length(model_name))
            assert embeddings.shape == (1, emb_width)

            records.append({
                'Name': model_name,
                'Source Repository': f'[{MODEL_NAMES[model_name]}](https://huggingface.co/{MODEL_NAMES[model_name]})',
                'Tags Count': _get_model_tags_length(model_name),
                'Embedding Width': emb_width,
            })
            # Drop the memoized download results once this model is done.
            _get_model_file.cache_clear()
            _get_model_tags_length.cache_clear()

        df_records = pd.DataFrame(records)
        # Explicit encoding so the README does not depend on the platform
        # default.
        with open(os.path.join(td, 'README.md'), 'w', encoding='utf-8') as f:
            # YAML front matter of the model card.
            print('---', file=f)
            print('license: apache-2.0', file=f)
            print('language:', file=f)
            print('- en', file=f)
            print('---', file=f)
            print('', file=f)

            print(
                f'This is onnx models based on [SmilingWolf](https://huggingface.co/SmilingWolf)\'s wd14 anime taggers, '
                f'which added the embeddings output as the second output.', file=f)
            print('', file=f)
            print(f'{plural_word(len(df_records), "model")} in total: ', file=f)
            print('', file=f)
            print(df_records.to_markdown(index=False), file=f)

        upload_directory_as_directory(
            repo_id='deepghs/wd14_tagger_with_embeddings',
            repo_type='model',
            local_directory=td,
            path_in_repo='.',
            # Pass the singular noun, as in the README line above, and let
            # plural_word do the inflection.
            message=f'Upload {plural_word(len(df_records), "model")}',
            clear=True,
        )