Commit 8e34a0fc authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use better way to export

parent 4ffea064
Loading
Loading
Loading
Loading
+55 −26
Original line number Diff line number Diff line
import glob
import os.path
import random
import re
import tempfile
from functools import partial
from typing import Optional, Tuple
@@ -20,7 +19,7 @@ from imgutils.data import load_image
from .dataset import TEST_TRANSFORM
from .demo import _get_model_from_ckpt
from .model import CCIP
from .onnx import export_full_model_to_onnx, export_feat_model_to_onnx, export_metrics_model_to_onnx
from .onnx import export_feat_model_to_onnx, export_metrics_model_to_onnx
from ..utils import GLOBAL_CONTEXT_SETTINGS
from ..utils import print_version as _origin_print_version

@@ -36,7 +35,6 @@ def cli():


_CHECK_ITEMS = {
    'full': export_full_model_to_onnx,
    'feat': export_feat_model_to_onnx,
    'metrics': export_metrics_model_to_onnx,
}
@@ -68,8 +66,35 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):
        threshold, accuracy_score(labels, predictions)


def _sample_safe_threshold(poss, negs, precision: float = 0.98) -> Tuple[float, float]:
    items = sorted([
        *((v, 1) for v in poss),
        *((v, 0) for v in negs),
    ], key=lambda x: (-x[0], -x[1]))

    pos_cnt, neg_cnt = 0, 0
    r_threshold, r_precision = None, None
    for i, (v, label) in enumerate(items):
        if label == 0:
            neg_cnt += 1
        else:
            pos_cnt += 1

        current_precision = pos_cnt / (pos_cnt + neg_cnt)
        if r_threshold is None or current_precision >= precision or current_precision > r_precision:
            if i == len(items) - 1:
                r_threshold = v
            else:
                v_next, _ = items[i + 1]
                r_threshold = (v + v_next) / 2
            r_precision = current_precision

    return r_threshold, r_precision


@torch.no_grad()
def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200) -> Tuple[float, float]:
def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200, safe_precision: float = 0.98) \
        -> Tuple[float, float, float, float]:
    def _get_sim(x, y):
        x, y = load_image(x, mode='RGB'), load_image(y, mode='RGB')
        input_ = torch.stack([preprocess(x), preprocess(y)])
@@ -95,43 +120,43 @@ def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200) -> Tupl
    same_samples = torch.as_tensor(same_samples)

    _, _, _, _, threshold, accuracy = _sample_analysis(same_samples, not_same_samples, svm_samples=samples)
    return threshold, accuracy
    safe_threshold, safe_prec = _sample_safe_threshold(same_samples, not_same_samples, precision=safe_precision)
    return threshold, accuracy, safe_threshold, safe_prec


@cli.command('onnx_check', help='Check onnx export is okay or not')
@click.option('--model', '-m', 'model', type=str, required=True,
              help='Model to be checked. ', show_default=True)
@click.option('--check', '-c', 'check_item', type=click.Choice(list(_CHECK_ITEMS.keys())), default=None,
              help='Model part to be checked. All parts will be checked when not given', show_default=True)
@click.option('--verbose', '-V', 'verbose', is_flag=True, type=bool, default=False,
              help='Show verbose information.', show_default=True)
@click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), default=None,
              help='Output directory of all models.', show_default=True)
@click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500,
              help='Batch of samples to find threshold.', show_default=True)
def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = False,
def onnx_check(model: str, verbose: bool = False,
               output_dir: Optional[str] = None, threshold_samples: int = 500):
    logging.try_init_root(logging.INFO)

    model, model_name = CCIP(model), model
    model.eval()
    if not check_item:
        check_items = list(_CHECK_ITEMS.keys())
    else:
        check_items = [check_item]

    logging.info('Finding threshold ...')
    threshold, accuracy = get_threshold_for_model(
    threshold_mean, accuracy_mean, threshold_safe, precision_safe = get_threshold_for_model(
        model,
        transforms.Compose(TEST_TRANSFORM + model.preprocess),
        samples=threshold_samples,
    )
    logging.info(f'Threshold: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%')
    logging.info(f'Threshold: {threshold_mean:.4f}, accuracy: {accuracy_mean * 100.0:.2f}%')
    logging.info(f'Safe threshold: {threshold_safe:.4f}, accuracy: {precision_safe * 100.0:.2f}%')

    with tempfile.TemporaryDirectory() as td:
        for item in check_items:
            click.echo(click.style(f'Try exporting {model_name}-->{item} to onnx ... '), nl=False)
            onnx_filename = os.path.join(output_dir or td, re.sub(r'\W+', '-', f'{model_name}_{item}') + '.onnx')
        for item, safe, threshold in [
            ('feat', False, threshold_mean),
            ('metrics', False, threshold_mean),
            ('metrics', True, threshold_safe),
        ]:
            click.echo(click.style(f'Try exporting {model_name}(safe={safe!r})-->{item} to onnx ... '), nl=False)
            onnx_filename = os.path.join(output_dir or td, f'{model_name}_{"safe_" if safe else ""}{item}.onnx')
            export_func = _CHECK_ITEMS[item]
            try:
                model = CCIP(model_name)  # necessary
@@ -148,8 +173,8 @@ def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = Fal


MODELS = [
    # ('caformer', 'ccip-caformer-2_fp32.ckpt'),
    # ('caformer', 'ccip-caformer-4_fp32.ckpt'),
    ('caformer', 'ccip-caformer-2_fp32.ckpt'),
    ('caformer', 'ccip-caformer-4_fp32.ckpt'),
    ('caformer', 'ccip-caformer-5_fp32.ckpt'),
]

@@ -163,25 +188,29 @@ MODELS = [
@click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500,
              help='Batch of samples to find threshold.', show_default=True)
def export(output_dir: str, verbose: bool = False, threshold_samples: int = 500):
    check_items = list(_CHECK_ITEMS.keys())

    for model_name, ckpt_name in MODELS:
        ckpt_file = hf_hub_download('deepghs/ccip', ckpt_name, repo_type='model')
        model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False)
        ckpt_body, _ = os.path.splitext(ckpt_name)

        logging.info(f'Finding threshold for {ckpt_name!r} ...')
        threshold, accuracy = get_threshold_for_model(
        threshold_mean, accuracy_mean, threshold_safe, precision_safe = get_threshold_for_model(
            model,
            transforms.Compose(TEST_TRANSFORM + model.preprocess),
            samples=threshold_samples,
        )
        logging.info(f'Threshold for {ckpt_file!r}: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%')
        logging.info(f'Threshold for {ckpt_file!r}: {threshold_mean:.4f}, accuracy: {accuracy_mean * 100.0:.2f}%')
        logging.info(f'Safe threshold for {ckpt_file!r}: {threshold_safe:.4f}, accuracy: {precision_safe * 100.0:.2f}%')

        with tempfile.TemporaryDirectory() as td:
            for item in check_items:
                click.echo(click.style(f'Try exporting {ckpt_body!r}({model_name})-->{item} to onnx ... '), nl=False)
                onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{item}.onnx')
            for item, safe, threshold in [
                ('feat', False, threshold_mean),
                ('metrics', False, threshold_mean),
                ('metrics', True, threshold_safe),
            ]:
                click.echo(click.style(f'Try exporting {ckpt_body!r}({model_name}, '
                                       f'safe={safe!r})-->{item} to onnx ... '), nl=False)
                onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{"safe_" if safe else ""}{item}.onnx')
                export_func = _CHECK_ITEMS[item]
                try:
                    model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False)