Commit ff0f6dc7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add more profiles, ci skip

parent 4e4e0ab4
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -28,3 +28,4 @@ git+https://github.com/deepghs/waifuc.git@main#egg=waifuc
pyquery
httpx
onnxslim==0.1.32
calflops
 No newline at end of file
+61 −13
Original line number Diff line number Diff line
@@ -7,14 +7,19 @@ import onnxruntime
import pandas as pd
import torch
from ditk import logging
from hbutils.string import titleize
from hbutils.system import TemporaryDirectory
from hbutils.testing import vpip
from hfutils.operate import get_hf_client, upload_directory_as_directory
from hfutils.repository import hf_hub_repo_url
from hfutils.repository import hf_hub_repo_url, hf_hub_repo_file_url
from hfutils.utils import hf_normpath
from huggingface_hub import hf_hub_url
from thop import clever_format

from imgutils.data import load_image
from imgutils.preprocess import parse_torchvision_transforms
from zoo.pixai_tagger.tags import load_tags
from zoo.utils import onnx_optimize, get_testfile
from zoo.utils import onnx_optimize, get_testfile, torch_model_profile_via_calflops
from .min_script import EndpointHandler
from .onnx import get_model

@@ -25,7 +30,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
        hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True)

    handler = EndpointHandler(repo_id=src_repo)

    meta_info = {}
    with TemporaryDirectory() as upload_dir:
        preprocessor = handler.transform
        preprocessor_file = os.path.join(upload_dir, 'preprocessor.json')
@@ -58,16 +63,16 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
        logging.info(f'Tags:\n{df_tags}')
        df_tags.to_csv(os.path.join(upload_dir, 'selected_tags.csv'), index=False)

        d_category_names = {
            0: 'general',
            4: 'character',
        }
        with open(os.path.join(upload_dir, 'categories.json'), 'w') as f:
            json.dump([
                {
                    "category": 0,
                    "name": "general"
                },
                {
                    "category": 4,
                    "name": "character"
                },
                    "category": cate_id,
                    "name": cate_name,
                } for cate_id, cate_name in sorted(d_category_names.items())
            ], f, sort_keys=True, ensure_ascii=False, indent=4)
        df_th = pd.DataFrame([
            {'category': 0, 'name': 'general', 'threshold': handler.default_general_threshold},
@@ -77,6 +82,15 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):

        dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white')
        dummy_input = handler.transform(dummy_image).unsqueeze(0).to(handler.device)
        flops, params, macs = torch_model_profile_via_calflops(model=handler.model, input_=dummy_input)
        meta_info['flops'] = flops
        meta_info['params'] = params
        meta_info['macs'] = macs
        new_meta_file = os.path.join(upload_dir, 'meta.json')
        logging.info(f'Saving metadata to {new_meta_file!r} ...')
        with open(new_meta_file, 'w') as f:
            json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False)

        wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input)
        conv_features = conv_features.detach().cpu()
        onnx_filename = os.path.join(upload_dir, 'model.onnx')
@@ -135,17 +149,51 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
            print('---', file=f)
            print('', file=f)

            print(f'PixAI-Tagger ONNX Version for {src_repo}', file=f)
            print(f'# ONNX Version for {src_repo}', file=f)
            print(f'', file=f)

            print(f'This is the ONNX-exported version of PixAI\'s tagger '
                  f'[{src_repo}]({hf_hub_repo_url(repo_id=src_repo, repo_type="model")}).', file=f)
            print(f'', file=f)

            print(f'# How To Use', file=f)
            s_flops, s_params, s_macs = clever_format([flops, params, macs], "%.1f")
            print(f'## Model Details', file=f)
            print(f'', file=f)
            print(f'- **Model Type:** Multilabel Image classification / feature backbone', file=f)
            print(f'- **Model Stats:**', file=f)
            print(f'  - Params: {s_params}', file=f)
            print(f'  - FLOPs / MACs: {s_flops} / {s_macs}', file=f)
            print(f'  - Image size: {dummy_input.shape[-1]} x {dummy_input.shape[-2]}', file=f)
            print(f'  - Tags Count: {len(df_tags)}', file=f)
            for category in sorted(set(df_tags['category'])):
                print(f'    - {titleize(d_category_names[category])} (#{category}) Tags Count: '
                      f'{len(df_tags[df_tags["category"] == category])}', file=f)
            print(f'', file=f)

            print(f'## How to Use', file=f)
            print(f'', file=f)

            imgutils_version = str(vpip('dghs-imgutils')._actual_version)
            sample_input = dummy_image
            if min(sample_input.width, sample_input.height) > 640:
                r = min(sample_input.width, sample_input.height) / 640
                new_width = int(sample_input.width / r)
                new_height = int(sample_input.height / r)
                sample_input = sample_input.resize((new_width, new_height))
            sample_input_file = os.path.join(upload_dir, 'sample.webp')
            sample_input_relfile = hf_normpath(os.path.relpath(sample_input_file, upload_dir))
            sample_input.save(sample_input_file)
            sample_input_url = hf_hub_url(repo_id=dst_repo, repo_type='model', filename=sample_input_relfile)
            sample_input_page_url = hf_hub_repo_file_url(repo_id=dst_repo, repo_type='model', path=sample_input_relfile)

            print(f'We provided a sample image for our code samples, '
                  f'you can find it [here]({sample_input_page_url}).', file=f)
            print(f'', file=f)

            print(f'Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command', file=f)
            print(f'', file=f)
            print(f'```shell', file=f)
            print(f'pip install -U dghs-imgutils', file=f)
            print(f'pip install \'dghs-imgutils>={imgutils_version}\' torch huggingface_hub timm pillow pandas', file=f)
            print(f'```', file=f)
            print(f'', file=f)

+1 −0
Original line number Diff line number Diff line
@@ -2,4 +2,5 @@ from .cli import GLOBAL_CONTEXT_SETTINGS, print_version
from .lr import get_init_lr, get_dynamic_lr_scheduler, LRTyping
from .onnx import onnx_quick_export
from .optimize import onnx_optimize
from .profile import torch_model_profile_via_thop, torch_model_profile_via_calflops
from .testfile import get_testfile

zoo/utils/profile.py

0 → 100644
+39 −0
Original line number Diff line number Diff line
import torch
from ditk import logging
from thop import profile, clever_format


def torch_model_profile_via_thop(model, input_):
    with torch.no_grad():
        flops, params = profile(model, (input_,))

    s_flops, s_params = clever_format([flops, params], "%.1f")
    logging.info(f'Params: {s_params}, FLOPs: {s_flops}.')

    return flops, params


def torch_model_profile_via_calflops(model, input_):
    from calflops import calculate_flops
    flops, macs, params = calculate_flops(
        model=model,
        input_shape=tuple(input_.shape),
        output_as_string=False,
        print_detailed=False,
        # output_as_string=True,
        # output_precision=4
    )
    s_flops, s_params, s_macs = clever_format([flops, params, macs], "%.1f")
    logging.info(f'Params: {s_params}, FLOPs: {s_flops}, MACs: {s_macs}.')
    return flops, params, macs


if __name__ == '__main__':
    logging.try_init_root(level=logging.INFO)
    from timm import create_model

    # model = create_model('hf-hub:animetimm/swinv2_base_window8_256.dbv4-full', pretrained=False)
    model = create_model('caformer_b36.sail_in22k_ft_in1k_384', pretrained=False)
    dummy_input = torch.randn(1, 3, 448, 448)
    print(torch_model_profile_via_thop(model, dummy_input))
    print(torch_model_profile_via_calflops(model, dummy_input))