Loading zoo/pixai_tagger/export.py +39 −3 Original line number Diff line number Diff line import json import os.path import re from pprint import pformat from textwrap import indent import numpy as np import onnx Loading @@ -7,7 +10,7 @@ import onnxruntime import pandas as pd import torch from ditk import logging from hbutils.string import titleize from hbutils.string import titleize, underscore from hbutils.system import TemporaryDirectory from hbutils.testing import vpip from hfutils.operate import get_hf_client, upload_directory_as_directory Loading @@ -18,13 +21,14 @@ from thop import clever_format from imgutils.data import load_image from imgutils.preprocess import parse_torchvision_transforms from imgutils.tagging import get_pixai_tags from zoo.pixai_tagger.tags import load_tags from zoo.utils import onnx_optimize, get_testfile, torch_model_profile_via_calflops from .min_script import EndpointHandler from .onnx import get_model def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): def sync(src_repo: str, dst_repo: str, no_optimize: bool = False, show_current_results: bool = False): hf_client = get_hf_client() if not hf_client.repo_exists(repo_id=dst_repo, repo_type='model'): hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True) Loading Loading @@ -212,6 +216,34 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): print(f'```', file=f) print(f'', file=f) matching = re.fullmatch(r'^deepghs/pixai-tagger-(?P<version>[\s\S]+)-onnx$', dst_repo) model_name = matching.group('version') if matching else dst_repo print(f'```python', file=f) print(f'from imgutils.tagging import get_pixai_tags', file=f) print(f'', file=f) cate_names = tuple((cate_name for _, cate_name in sorted(d_category_names.items()))) fmt_names = tuple([*cate_names, 'ips', 'ips_mapping']) var_names = tuple(map(underscore, fmt_names)) print(f'{", ".join(var_names)} = get_pixai_tags(', file=f) print(f' {sample_input_url!r},', file=f) print(f' model_name={model_name!r},', file=f) print(f' fmt={fmt_names!r},', file=f) print(f')', file=f) print(f'', file=f) var_values = get_pixai_tags(sample_input_file, model_name=model_name, fmt=fmt_names) for varname, fmtname, varvalue in zip(var_names, fmt_names, var_values): print(f'print({varname})', file=f) print(indent( pformat(varvalue, sort_dicts=False), prefix='# ', ), file=f) print(f'', file=f) print(f'```', file=f) print(f'', file=f) upload_directory_as_directory( repo_id=dst_repo, repo_type='model', Loading @@ -231,7 +263,11 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): if __name__ == '__main__': logging.try_init_root(logging.INFO) # matching = re.fullmatch(r'^deepghs/pixai-tagger-(?P<version>[\s\S]+)-onnx$', 'deepghs/pixai-tagger-v0.9-onnx') # print(matching) # print(matching.group('version')) sync( src_repo='pixai-labs/pixai-tagger-v0.9', dst_repo='deepghs/pixai-tagger-v0.9-onnx' dst_repo='deepghs/pixai-tagger-v0.9-onnx', show_current_results=True, ) Loading
zoo/pixai_tagger/export.py +39 −3 Original line number Diff line number Diff line import json import os.path import re from pprint import pformat from textwrap import indent import numpy as np import onnx Loading @@ -7,7 +10,7 @@ import onnxruntime import pandas as pd import torch from ditk import logging from hbutils.string import titleize from hbutils.string import titleize, underscore from hbutils.system import TemporaryDirectory from hbutils.testing import vpip from hfutils.operate import get_hf_client, upload_directory_as_directory Loading @@ -18,13 +21,14 @@ from thop import clever_format from imgutils.data import load_image from imgutils.preprocess import parse_torchvision_transforms from imgutils.tagging import get_pixai_tags from zoo.pixai_tagger.tags import load_tags from zoo.utils import onnx_optimize, get_testfile, torch_model_profile_via_calflops from .min_script import EndpointHandler from .onnx import get_model def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): def sync(src_repo: str, dst_repo: str, no_optimize: bool = False, show_current_results: bool = False): hf_client = get_hf_client() if not hf_client.repo_exists(repo_id=dst_repo, repo_type='model'): hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True) Loading Loading @@ -212,6 +216,34 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): print(f'```', file=f) print(f'', file=f) matching = re.fullmatch(r'^deepghs/pixai-tagger-(?P<version>[\s\S]+)-onnx$', dst_repo) model_name = matching.group('version') if matching else dst_repo print(f'```python', file=f) print(f'from imgutils.tagging import get_pixai_tags', file=f) print(f'', file=f) cate_names = tuple((cate_name for _, cate_name in sorted(d_category_names.items()))) fmt_names = tuple([*cate_names, 'ips', 'ips_mapping']) var_names = tuple(map(underscore, fmt_names)) print(f'{", ".join(var_names)} = get_pixai_tags(', file=f) print(f' {sample_input_url!r},', file=f) print(f' model_name={model_name!r},', file=f) print(f' fmt={fmt_names!r},', file=f) print(f')', file=f) print(f'', file=f) var_values = get_pixai_tags(sample_input_file, model_name=model_name, fmt=fmt_names) for varname, fmtname, varvalue in zip(var_names, fmt_names, var_values): print(f'print({varname})', file=f) print(indent( pformat(varvalue, sort_dicts=False), prefix='# ', ), file=f) print(f'', file=f) print(f'```', file=f) print(f'', file=f) upload_directory_as_directory( repo_id=dst_repo, repo_type='model', Loading @@ -231,7 +263,11 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): if __name__ == '__main__': logging.try_init_root(logging.INFO) # matching = re.fullmatch(r'^deepghs/pixai-tagger-(?P<version>[\s\S]+)-onnx$', 'deepghs/pixai-tagger-v0.9-onnx') # print(matching) # print(matching.group('version')) sync( src_repo='pixai-labs/pixai-tagger-v0.9', dst_repo='deepghs/pixai-tagger-v0.9-onnx' dst_repo='deepghs/pixai-tagger-v0.9-onnx', show_current_results=True, )