Loading requirements-zoo.txt +2 −1 Original line number Diff line number Diff line Loading @@ -28,3 +28,4 @@ git+https://github.com/deepghs/waifuc.git@main#egg=waifuc pyquery httpx onnxslim==0.1.32 calflops No newline at end of file zoo/pixai_tagger/export.py +61 −13 Original line number Diff line number Diff line Loading @@ -7,14 +7,19 @@ import onnxruntime import pandas as pd import torch from ditk import logging from hbutils.string import titleize from hbutils.system import TemporaryDirectory from hbutils.testing import vpip from hfutils.operate import get_hf_client, upload_directory_as_directory from hfutils.repository import hf_hub_repo_url from hfutils.repository import hf_hub_repo_url, hf_hub_repo_file_url from hfutils.utils import hf_normpath from huggingface_hub import hf_hub_url from thop import clever_format from imgutils.data import load_image from imgutils.preprocess import parse_torchvision_transforms from zoo.pixai_tagger.tags import load_tags from zoo.utils import onnx_optimize, get_testfile from zoo.utils import onnx_optimize, get_testfile, torch_model_profile_via_calflops from .min_script import EndpointHandler from .onnx import get_model Loading @@ -25,7 +30,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True) handler = EndpointHandler(repo_id=src_repo) meta_info = {} with TemporaryDirectory() as upload_dir: preprocessor = handler.transform preprocessor_file = os.path.join(upload_dir, 'preprocessor.json') Loading Loading @@ -58,16 +63,16 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): logging.info(f'Tags:\n{df_tags}') df_tags.to_csv(os.path.join(upload_dir, 'selected_tags.csv'), index=False) d_category_names = { 0: 'general', 4: 'character', } with open(os.path.join(upload_dir, 'categories.json'), 'w') as f: json.dump([ { "category": 0, "name": "general" }, { "category": 4, "name": "character" }, "category": cate_id, "name": cate_name, } for cate_id, cate_name in sorted(d_category_names.items()) ], f, sort_keys=True, ensure_ascii=False, indent=4) df_th = pd.DataFrame([ {'category': 0, 'name': 'general', 'threshold': handler.default_general_threshold}, Loading @@ -77,6 +82,15 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white') dummy_input = handler.transform(dummy_image).unsqueeze(0).to(handler.device) flops, params, macs = torch_model_profile_via_calflops(model=handler.model, input_=dummy_input) meta_info['flops'] = flops meta_info['params'] = params meta_info['macs'] = macs new_meta_file = os.path.join(upload_dir, 'meta.json') logging.info(f'Saving metadata to {new_meta_file!r} ...') with open(new_meta_file, 'w') as f: json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False) wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input) conv_features = conv_features.detach().cpu() onnx_filename = os.path.join(upload_dir, 'model.onnx') Loading Loading @@ -135,17 +149,51 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): print('---', file=f) print('', file=f) print(f'PixAI-Tagger ONNX Version for {src_repo}', file=f) print(f'# ONNX Version for {src_repo}', file=f) print(f'', file=f) print(f'This is the ONNX-exported version of PixAI\'s tagger ' f'[{src_repo}]({hf_hub_repo_url(repo_id=src_repo, repo_type="model")}).', file=f) print(f'', file=f) print(f'# How To Use', file=f) s_flops, s_params, s_macs = clever_format([flops, params, macs], "%.1f") print(f'## Model Details', file=f) print(f'', file=f) print(f'- **Model Type:** Multilabel Image classification / feature backbone', file=f) print(f'- **Model Stats:**', file=f) print(f' - Params: {s_params}', file=f) print(f' - FLOPs / MACs: {s_flops} / {s_macs}', file=f) print(f' - Image size: {dummy_input.shape[-1]} x {dummy_input.shape[-2]}', file=f) print(f' - Tags Count: {len(df_tags)}', file=f) for category in sorted(set(df_tags['category'])): print(f' - {titleize(d_category_names[category])} (#{category}) Tags Count: ' f'{len(df_tags[df_tags["category"] == category])}', file=f) print(f'', file=f) print(f'## How to Use', file=f) print(f'', file=f) imgutils_version = str(vpip('dghs-imgutils')._actual_version) sample_input = dummy_image if min(sample_input.width, sample_input.height) > 640: r = min(sample_input.width, sample_input.height) / 640 new_width = int(sample_input.width / r) new_height = int(sample_input.height / r) sample_input = sample_input.resize((new_width, new_height)) sample_input_file = os.path.join(upload_dir, 'sample.webp') sample_input_relfile = hf_normpath(os.path.relpath(sample_input_file, upload_dir)) sample_input.save(sample_input_file) sample_input_url = hf_hub_url(repo_id=dst_repo, repo_type='model', filename=sample_input_relfile) sample_input_page_url = hf_hub_repo_file_url(repo_id=dst_repo, repo_type='model', path=sample_input_relfile) print(f'We provided a sample image for our code samples, ' f'you can find it [here]({sample_input_page_url}).', file=f) print(f'', file=f) print(f'Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command', file=f) print(f'', file=f) print(f'```shell', file=f) print(f'pip install -U dghs-imgutils', file=f) print(f'pip install \'dghs-imgutils>={imgutils_version}\' torch huggingface_hub timm pillow pandas', file=f) print(f'```', file=f) print(f'', file=f) Loading zoo/utils/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -2,4 +2,5 @@ from .cli import GLOBAL_CONTEXT_SETTINGS, print_version from .lr import get_init_lr, get_dynamic_lr_scheduler, LRTyping from .onnx import onnx_quick_export from .optimize import onnx_optimize from .profile import torch_model_profile_via_thop, torch_model_profile_via_calflops from .testfile import get_testfile zoo/utils/profile.py 0 → 100644 +39 −0 Original line number Diff line number Diff line import torch from ditk import logging from thop import profile, clever_format def torch_model_profile_via_thop(model, input_): with torch.no_grad(): flops, params = profile(model, (input_,)) s_flops, s_params = clever_format([flops, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_flops}.') return flops, params def torch_model_profile_via_calflops(model, input_): from calflops import calculate_flops flops, macs, params = calculate_flops( model=model, input_shape=tuple(input_.shape), output_as_string=False, print_detailed=False, # output_as_string=True, # output_precision=4 ) s_flops, s_params, s_macs = clever_format([flops, params, macs], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_flops}, MACs: {s_macs}.') return flops, params, macs if __name__ == '__main__': logging.try_init_root(level=logging.INFO) from timm import create_model # model = create_model('hf-hub:animetimm/swinv2_base_window8_256.dbv4-full', pretrained=False) model = create_model('caformer_b36.sail_in22k_ft_in1k_384', pretrained=False) dummy_input = torch.randn(1, 3, 448, 448) print(torch_model_profile_via_thop(model, dummy_input)) print(torch_model_profile_via_calflops(model, dummy_input)) Loading
requirements-zoo.txt +2 −1 Original line number Diff line number Diff line Loading @@ -28,3 +28,4 @@ git+https://github.com/deepghs/waifuc.git@main#egg=waifuc pyquery httpx onnxslim==0.1.32 calflops No newline at end of file
zoo/pixai_tagger/export.py +61 −13 Original line number Diff line number Diff line Loading @@ -7,14 +7,19 @@ import onnxruntime import pandas as pd import torch from ditk import logging from hbutils.string import titleize from hbutils.system import TemporaryDirectory from hbutils.testing import vpip from hfutils.operate import get_hf_client, upload_directory_as_directory from hfutils.repository import hf_hub_repo_url from hfutils.repository import hf_hub_repo_url, hf_hub_repo_file_url from hfutils.utils import hf_normpath from huggingface_hub import hf_hub_url from thop import clever_format from imgutils.data import load_image from imgutils.preprocess import parse_torchvision_transforms from zoo.pixai_tagger.tags import load_tags from zoo.utils import onnx_optimize, get_testfile from zoo.utils import onnx_optimize, get_testfile, torch_model_profile_via_calflops from .min_script import EndpointHandler from .onnx import get_model Loading @@ -25,7 +30,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True) handler = EndpointHandler(repo_id=src_repo) meta_info = {} with TemporaryDirectory() as upload_dir: preprocessor = handler.transform preprocessor_file = os.path.join(upload_dir, 'preprocessor.json') Loading Loading @@ -58,16 +63,16 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): logging.info(f'Tags:\n{df_tags}') df_tags.to_csv(os.path.join(upload_dir, 'selected_tags.csv'), index=False) d_category_names = { 0: 'general', 4: 'character', } with open(os.path.join(upload_dir, 'categories.json'), 'w') as f: json.dump([ { "category": 0, "name": "general" }, { "category": 4, "name": "character" }, "category": cate_id, "name": cate_name, } for cate_id, cate_name in sorted(d_category_names.items()) ], f, sort_keys=True, ensure_ascii=False, indent=4) df_th = pd.DataFrame([ {'category': 0, 'name': 'general', 'threshold': handler.default_general_threshold}, Loading @@ -77,6 +82,15 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white') dummy_input = handler.transform(dummy_image).unsqueeze(0).to(handler.device) flops, params, macs = torch_model_profile_via_calflops(model=handler.model, input_=dummy_input) meta_info['flops'] = flops meta_info['params'] = params meta_info['macs'] = macs new_meta_file = os.path.join(upload_dir, 'meta.json') logging.info(f'Saving metadata to {new_meta_file!r} ...') with open(new_meta_file, 'w') as f: json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False) wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input) conv_features = conv_features.detach().cpu() onnx_filename = os.path.join(upload_dir, 'model.onnx') Loading Loading @@ -135,17 +149,51 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): print('---', file=f) print('', file=f) print(f'PixAI-Tagger ONNX Version for {src_repo}', file=f) print(f'# ONNX Version for {src_repo}', file=f) print(f'', file=f) print(f'This is the ONNX-exported version of PixAI\'s tagger ' f'[{src_repo}]({hf_hub_repo_url(repo_id=src_repo, repo_type="model")}).', file=f) print(f'', file=f) print(f'# How To Use', file=f) s_flops, s_params, s_macs = clever_format([flops, params, macs], "%.1f") print(f'## Model Details', file=f) print(f'', file=f) print(f'- **Model Type:** Multilabel Image classification / feature backbone', file=f) print(f'- **Model Stats:**', file=f) print(f' - Params: {s_params}', file=f) print(f' - FLOPs / MACs: {s_flops} / {s_macs}', file=f) print(f' - Image size: {dummy_input.shape[-1]} x {dummy_input.shape[-2]}', file=f) print(f' - Tags Count: {len(df_tags)}', file=f) for category in sorted(set(df_tags['category'])): print(f' - {titleize(d_category_names[category])} (#{category}) Tags Count: ' f'{len(df_tags[df_tags["category"] == category])}', file=f) print(f'', file=f) print(f'## How to Use', file=f) print(f'', file=f) imgutils_version = str(vpip('dghs-imgutils')._actual_version) sample_input = dummy_image if min(sample_input.width, sample_input.height) > 640: r = min(sample_input.width, sample_input.height) / 640 new_width = int(sample_input.width / r) new_height = int(sample_input.height / r) sample_input = sample_input.resize((new_width, new_height)) sample_input_file = os.path.join(upload_dir, 'sample.webp') sample_input_relfile = hf_normpath(os.path.relpath(sample_input_file, upload_dir)) sample_input.save(sample_input_file) sample_input_url = hf_hub_url(repo_id=dst_repo, repo_type='model', filename=sample_input_relfile) sample_input_page_url = hf_hub_repo_file_url(repo_id=dst_repo, repo_type='model', path=sample_input_relfile) print(f'We provided a sample image for our code samples, ' f'you can find it [here]({sample_input_page_url}).', file=f) print(f'', file=f) print(f'Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command', file=f) print(f'', file=f) print(f'```shell', file=f) print(f'pip install -U dghs-imgutils', file=f) print(f'pip install \'dghs-imgutils>={imgutils_version}\' torch huggingface_hub timm pillow pandas', file=f) print(f'```', file=f) print(f'', file=f) Loading
zoo/utils/__init__.py +1 −0 Original line number Diff line number Diff line Loading @@ -2,4 +2,5 @@ from .cli import GLOBAL_CONTEXT_SETTINGS, print_version from .lr import get_init_lr, get_dynamic_lr_scheduler, LRTyping from .onnx import onnx_quick_export from .optimize import onnx_optimize from .profile import torch_model_profile_via_thop, torch_model_profile_via_calflops from .testfile import get_testfile
zoo/utils/profile.py 0 → 100644 +39 −0 Original line number Diff line number Diff line import torch from ditk import logging from thop import profile, clever_format def torch_model_profile_via_thop(model, input_): with torch.no_grad(): flops, params = profile(model, (input_,)) s_flops, s_params = clever_format([flops, params], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_flops}.') return flops, params def torch_model_profile_via_calflops(model, input_): from calflops import calculate_flops flops, macs, params = calculate_flops( model=model, input_shape=tuple(input_.shape), output_as_string=False, print_detailed=False, # output_as_string=True, # output_precision=4 ) s_flops, s_params, s_macs = clever_format([flops, params, macs], "%.1f") logging.info(f'Params: {s_params}, FLOPs: {s_flops}, MACs: {s_macs}.') return flops, params, macs if __name__ == '__main__': logging.try_init_root(level=logging.INFO) from timm import create_model # model = create_model('hf-hub:animetimm/swinv2_base_window8_256.dbv4-full', pretrained=False) model = create_model('caformer_b36.sail_in22k_ft_in1k_384', pretrained=False) dummy_input = torch.randn(1, 3, 448, 448) print(torch_model_profile_via_thop(model, dummy_input)) print(torch_model_profile_via_calflops(model, dummy_input))