Commit 3ae7b5fc authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): new code support, ci skip

parent c92aa23d
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -4,6 +4,10 @@ on:
  #  push:
  workflow_dispatch:
    inputs:
      models:
        description: 'Models To Make'
        type: str
        default: ''
      tag_lazy_mode:
        description: 'Enable Tag Lazy Mode'
        type: boolean
@@ -55,5 +59,6 @@ jobs:
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
          MODELS: ${{ github.event.inputs.models || '' }}
        run: |
          python -m zoo.wd14.sync
+2 −0
Original line number Diff line number Diff line
@@ -27,12 +27,14 @@ CONV_V3_MODEL_REPO = 'SmilingWolf/wd-convnext-tagger-v3'
SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3'
VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3'
VIT_LARGE_MODEL_REPO = 'SmilingWolf/wd-vit-large-tagger-v3'
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

_IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17'

MODEL_NAMES = {
    "EVA02_Large": EVA02_LARGE_MODEL_DSV3_REPO,
    "ViT_Large": VIT_LARGE_MODEL_REPO,

    "SwinV2": SWIN_MODEL_REPO,
+14 −3
Original line number Diff line number Diff line
import os.path
import re
from functools import lru_cache
from typing import List, Optional

import numpy as np
import onnx
@@ -48,16 +49,24 @@ _FC_NODE_PREFIXES_FOR_V3 = {
    "ConvNext_v3": ('core_model', 'head', 'fc'),
    "ViT_v3": ('core_model', 'head'),
    "ViT_Large": ('core_model', 'head'),
    "EVA02_Large": ('core_model', 'head'),
}


def sync(tag_lazy_mode: bool = False):
def sync(tag_lazy_mode: bool = False, models: Optional[List[str]] = None):
    hf_fs = get_hf_fs()

    if models:
        _make_all = False
        _model_names = models
    else:
        _make_all = True
        _model_names = MODEL_NAMES

    import onnxruntime
    with TemporaryDirectory() as td:
        records = []
        for model_name in tqdm(MODEL_NAMES):
        for model_name in tqdm(_model_names):
            model_file = _get_model_file(model_name)
            logging.info(f'Model name: {model_name!r}, model file: {model_file!r}')
            logging.info(f'Loading model {model_name!r} ...')
@@ -171,11 +180,13 @@ def sync(tag_lazy_mode: bool = False):
            local_directory=td,
            path_in_repo='.',
            message=f'Upload {plural_word(len(df_records), "models")}',
            clear=True,
            clear=True if _make_all else False,
        )


if __name__ == '__main__':
    _MODELS = list(filter(bool, re.split('[,\s]+', os.environ.get('MODELS') or '')))
    sync(
        tag_lazy_mode=bool(os.environ.get('TAG_LAZY_MODE')),
        models=_MODELS if _MODELS else None,
    )