Commit 16f0cd9e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): try fix unittest, ci skip

parent 138ada24
Loading
Loading
Loading
Loading
+39 −3
Original line number Diff line number Diff line
import json
import os.path
import re
from pprint import pformat
from textwrap import indent

import numpy as np
import onnx
@@ -7,7 +10,7 @@ import onnxruntime
import pandas as pd
import torch
from ditk import logging
from hbutils.string import titleize
from hbutils.string import titleize, underscore
from hbutils.system import TemporaryDirectory
from hbutils.testing import vpip
from hfutils.operate import get_hf_client, upload_directory_as_directory
@@ -18,13 +21,14 @@ from thop import clever_format

from imgutils.data import load_image
from imgutils.preprocess import parse_torchvision_transforms
from imgutils.tagging import get_pixai_tags
from zoo.pixai_tagger.tags import load_tags
from zoo.utils import onnx_optimize, get_testfile, torch_model_profile_via_calflops
from .min_script import EndpointHandler
from .onnx import get_model


def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
def sync(src_repo: str, dst_repo: str, no_optimize: bool = False, show_current_results: bool = False):
    hf_client = get_hf_client()
    if not hf_client.repo_exists(repo_id=dst_repo, repo_type='model'):
        hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True)
@@ -212,6 +216,34 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
            print(f'```', file=f)
            print(f'', file=f)

            matching = re.fullmatch(r'^deepghs/pixai-tagger-(?P<version>[\s\S]+)-onnx$', dst_repo)
            model_name = matching.group('version') if matching else dst_repo
            print(f'```python', file=f)
            print(f'from imgutils.tagging import get_pixai_tags', file=f)
            print(f'', file=f)

            cate_names = tuple((cate_name for _, cate_name in sorted(d_category_names.items())))
            fmt_names = tuple([*cate_names, 'ips', 'ips_mapping'])
            var_names = tuple(map(underscore, fmt_names))
            print(f'{", ".join(var_names)} = get_pixai_tags(', file=f)
            print(f'    {sample_input_url!r},', file=f)
            print(f'    model_name={model_name!r},', file=f)
            print(f'    fmt={fmt_names!r},', file=f)
            print(f')', file=f)
            print(f'', file=f)

            var_values = get_pixai_tags(sample_input_file, model_name=model_name, fmt=fmt_names)
            for varname, fmtname, varvalue in zip(var_names, fmt_names, var_values):
                print(f'print({varname})', file=f)
                print(indent(
                    pformat(varvalue, sort_dicts=False),
                    prefix='# ',
                ), file=f)

                print(f'', file=f)
            print(f'```', file=f)
            print(f'', file=f)

        upload_directory_as_directory(
            repo_id=dst_repo,
            repo_type='model',
@@ -231,7 +263,11 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):

if __name__ == '__main__':
    logging.try_init_root(logging.INFO)
    # matching = re.fullmatch(r'^deepghs/pixai-tagger-(?P<version>[\s\S]+)-onnx$', 'deepghs/pixai-tagger-v0.9-onnx')
    # print(matching)
    # print(matching.group('version'))
    sync(
        src_repo='pixai-labs/pixai-tagger-v0.9',
        dst_repo='deepghs/pixai-tagger-v0.9-onnx'
        dst_repo='deepghs/pixai-tagger-v0.9-onnx',
        show_current_results=True,
    )