Loading zoo/ccip/__main__.py +55 −26 Original line number Diff line number Diff line import glob import os.path import random import re import tempfile from functools import partial from typing import Optional, Tuple Loading @@ -20,7 +19,7 @@ from imgutils.data import load_image from .dataset import TEST_TRANSFORM from .demo import _get_model_from_ckpt from .model import CCIP from .onnx import export_full_model_to_onnx, export_feat_model_to_onnx, export_metrics_model_to_onnx from .onnx import export_feat_model_to_onnx, export_metrics_model_to_onnx from ..utils import GLOBAL_CONTEXT_SETTINGS from ..utils import print_version as _origin_print_version Loading @@ -36,7 +35,6 @@ def cli(): _CHECK_ITEMS = { 'full': export_full_model_to_onnx, 'feat': export_feat_model_to_onnx, 'metrics': export_metrics_model_to_onnx, } Loading Loading @@ -68,8 +66,35 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000): threshold, accuracy_score(labels, predictions) def _sample_safe_threshold(poss, negs, precision: float = 0.98) -> Tuple[float, float]: items = sorted([ *((v, 1) for v in poss), *((v, 0) for v in negs), ], key=lambda x: (-x[0], -x[1])) pos_cnt, neg_cnt = 0, 0 r_threshold, r_precision = None, None for i, (v, label) in enumerate(items): if label == 0: neg_cnt += 1 else: pos_cnt += 1 current_precision = pos_cnt / (pos_cnt + neg_cnt) if r_threshold is None or current_precision >= precision or current_precision > r_precision: if i == len(items) - 1: r_threshold = v else: v_next, _ = items[i + 1] r_threshold = (v + v_next) / 2 r_precision = current_precision return r_threshold, r_precision @torch.no_grad() def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200) -> Tuple[float, float]: def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200, safe_precision: float = 0.98) \ -> Tuple[float, float, float, float]: def _get_sim(x, y): x, y = load_image(x, mode='RGB'), load_image(y, mode='RGB') input_ = torch.stack([preprocess(x), preprocess(y)]) Loading @@ -95,43 +120,43 @@ def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200) -> Tupl same_samples = torch.as_tensor(same_samples) _, _, _, _, threshold, accuracy = _sample_analysis(same_samples, not_same_samples, svm_samples=samples) return threshold, accuracy safe_threshold, safe_prec = _sample_safe_threshold(same_samples, not_same_samples, precision=safe_precision) return threshold, accuracy, safe_threshold, safe_prec @cli.command('onnx_check', help='Check onnx export is okay or not') @click.option('--model', '-m', 'model', type=str, required=True, help='Model to be checked. ', show_default=True) @click.option('--check', '-c', 'check_item', type=click.Choice(list(_CHECK_ITEMS.keys())), default=None, help='Model part to be checked. All parts will be checked when not given', show_default=True) @click.option('--verbose', '-V', 'verbose', is_flag=True, type=bool, default=False, help='Show verbose information.', show_default=True) @click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), default=None, help='Output directory of all models.', show_default=True) @click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500, help='Batch of samples to find threshold.', show_default=True) def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = False, def onnx_check(model: str, verbose: bool = False, output_dir: Optional[str] = None, threshold_samples: int = 500): logging.try_init_root(logging.INFO) model, model_name = CCIP(model), model model.eval() if not check_item: check_items = list(_CHECK_ITEMS.keys()) else: check_items = [check_item] logging.info('Finding threshold ...') threshold, accuracy = get_threshold_for_model( threshold_mean, accuracy_mean, threshold_safe, precision_safe = get_threshold_for_model( model, transforms.Compose(TEST_TRANSFORM + model.preprocess), samples=threshold_samples, ) logging.info(f'Threshold: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%') logging.info(f'Threshold: {threshold_mean:.4f}, accuracy: {accuracy_mean * 100.0:.2f}%') logging.info(f'Safe threshold: {threshold_safe:.4f}, accuracy: {precision_safe * 100.0:.2f}%') with tempfile.TemporaryDirectory() as td: for item in check_items: click.echo(click.style(f'Try exporting {model_name}-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, re.sub(r'\W+', '-', f'{model_name}_{item}') + '.onnx') for item, safe, threshold in [ ('feat', False, threshold_mean), ('metrics', False, threshold_mean), ('metrics', True, threshold_safe), ]: click.echo(click.style(f'Try exporting {model_name}(safe={safe!r})-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, f'{model_name}_{"safe_" if safe else ""}{item}.onnx') export_func = _CHECK_ITEMS[item] try: model = CCIP(model_name) # necessary Loading @@ -148,8 +173,8 @@ def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = Fal MODELS = [ # ('caformer', 'ccip-caformer-2_fp32.ckpt'), # ('caformer', 'ccip-caformer-4_fp32.ckpt'), ('caformer', 'ccip-caformer-2_fp32.ckpt'), ('caformer', 'ccip-caformer-4_fp32.ckpt'), ('caformer', 'ccip-caformer-5_fp32.ckpt'), ] Loading @@ -163,25 +188,29 @@ MODELS = [ @click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500, help='Batch of samples to find threshold.', show_default=True) def export(output_dir: str, verbose: bool = False, threshold_samples: int = 500): check_items = list(_CHECK_ITEMS.keys()) for model_name, ckpt_name in MODELS: ckpt_file = hf_hub_download('deepghs/ccip', ckpt_name, repo_type='model') model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False) ckpt_body, _ = os.path.splitext(ckpt_name) logging.info(f'Finding threshold for {ckpt_name!r} ...') threshold, accuracy = get_threshold_for_model( threshold_mean, accuracy_mean, threshold_safe, precision_safe = get_threshold_for_model( model, transforms.Compose(TEST_TRANSFORM + model.preprocess), samples=threshold_samples, ) logging.info(f'Threshold for {ckpt_file!r}: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%') logging.info(f'Threshold for {ckpt_file!r}: {threshold_mean:.4f}, accuracy: {accuracy_mean * 100.0:.2f}%') logging.info(f'Safe threshold for {ckpt_file!r}: {threshold_safe:.4f}, accuracy: {precision_safe * 100.0:.2f}%') with tempfile.TemporaryDirectory() as td: for item in check_items: click.echo(click.style(f'Try exporting {ckpt_body!r}({model_name})-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{item}.onnx') for item, safe, threshold in [ ('feat', False, threshold_mean), ('metrics', False, threshold_mean), ('metrics', True, threshold_safe), ]: click.echo(click.style(f'Try exporting {ckpt_body!r}({model_name}, ' f'safe={safe!r})-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{"safe_" if safe else ""}{item}.onnx') export_func = _CHECK_ITEMS[item] try: model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False) Loading Loading
zoo/ccip/__main__.py +55 −26 Original line number Diff line number Diff line import glob import os.path import random import re import tempfile from functools import partial from typing import Optional, Tuple Loading @@ -20,7 +19,7 @@ from imgutils.data import load_image from .dataset import TEST_TRANSFORM from .demo import _get_model_from_ckpt from .model import CCIP from .onnx import export_full_model_to_onnx, export_feat_model_to_onnx, export_metrics_model_to_onnx from .onnx import export_feat_model_to_onnx, export_metrics_model_to_onnx from ..utils import GLOBAL_CONTEXT_SETTINGS from ..utils import print_version as _origin_print_version Loading @@ -36,7 +35,6 @@ def cli(): _CHECK_ITEMS = { 'full': export_full_model_to_onnx, 'feat': export_feat_model_to_onnx, 'metrics': export_metrics_model_to_onnx, } Loading Loading @@ -68,8 +66,35 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000): threshold, accuracy_score(labels, predictions) def _sample_safe_threshold(poss, negs, precision: float = 0.98) -> Tuple[float, float]: items = sorted([ *((v, 1) for v in poss), *((v, 0) for v in negs), ], key=lambda x: (-x[0], -x[1])) pos_cnt, neg_cnt = 0, 0 r_threshold, r_precision = None, None for i, (v, label) in enumerate(items): if label == 0: neg_cnt += 1 else: pos_cnt += 1 current_precision = pos_cnt / (pos_cnt + neg_cnt) if r_threshold is None or current_precision >= precision or current_precision > r_precision: if i == len(items) - 1: r_threshold = v else: v_next, _ = items[i + 1] r_threshold = (v + v_next) / 2 r_precision = current_precision return r_threshold, r_precision @torch.no_grad() def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200) -> Tuple[float, float]: def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200, safe_precision: float = 0.98) \ -> Tuple[float, float, float, float]: def _get_sim(x, y): x, y = load_image(x, mode='RGB'), load_image(y, mode='RGB') input_ = torch.stack([preprocess(x), preprocess(y)]) Loading @@ -95,43 +120,43 @@ def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200) -> Tupl same_samples = torch.as_tensor(same_samples) _, _, _, _, threshold, accuracy = _sample_analysis(same_samples, not_same_samples, svm_samples=samples) return threshold, accuracy safe_threshold, safe_prec = _sample_safe_threshold(same_samples, not_same_samples, precision=safe_precision) return threshold, accuracy, safe_threshold, safe_prec @cli.command('onnx_check', help='Check onnx export is okay or not') @click.option('--model', '-m', 'model', type=str, required=True, help='Model to be checked. ', show_default=True) @click.option('--check', '-c', 'check_item', type=click.Choice(list(_CHECK_ITEMS.keys())), default=None, help='Model part to be checked. All parts will be checked when not given', show_default=True) @click.option('--verbose', '-V', 'verbose', is_flag=True, type=bool, default=False, help='Show verbose information.', show_default=True) @click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), default=None, help='Output directory of all models.', show_default=True) @click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500, help='Batch of samples to find threshold.', show_default=True) def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = False, def onnx_check(model: str, verbose: bool = False, output_dir: Optional[str] = None, threshold_samples: int = 500): logging.try_init_root(logging.INFO) model, model_name = CCIP(model), model model.eval() if not check_item: check_items = list(_CHECK_ITEMS.keys()) else: check_items = [check_item] logging.info('Finding threshold ...') threshold, accuracy = get_threshold_for_model( threshold_mean, accuracy_mean, threshold_safe, precision_safe = get_threshold_for_model( model, transforms.Compose(TEST_TRANSFORM + model.preprocess), samples=threshold_samples, ) logging.info(f'Threshold: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%') logging.info(f'Threshold: {threshold_mean:.4f}, accuracy: {accuracy_mean * 100.0:.2f}%') logging.info(f'Safe threshold: {threshold_safe:.4f}, accuracy: {precision_safe * 100.0:.2f}%') with tempfile.TemporaryDirectory() as td: for item in check_items: click.echo(click.style(f'Try exporting {model_name}-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, re.sub(r'\W+', '-', f'{model_name}_{item}') + '.onnx') for item, safe, threshold in [ ('feat', False, threshold_mean), ('metrics', False, threshold_mean), ('metrics', True, threshold_safe), ]: click.echo(click.style(f'Try exporting {model_name}(safe={safe!r})-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, f'{model_name}_{"safe_" if safe else ""}{item}.onnx') export_func = _CHECK_ITEMS[item] try: model = CCIP(model_name) # necessary Loading @@ -148,8 +173,8 @@ def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = Fal MODELS = [ # ('caformer', 'ccip-caformer-2_fp32.ckpt'), # ('caformer', 'ccip-caformer-4_fp32.ckpt'), ('caformer', 'ccip-caformer-2_fp32.ckpt'), ('caformer', 'ccip-caformer-4_fp32.ckpt'), ('caformer', 'ccip-caformer-5_fp32.ckpt'), ] Loading @@ -163,25 +188,29 @@ MODELS = [ @click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500, help='Batch of samples to find threshold.', show_default=True) def export(output_dir: str, verbose: bool = False, threshold_samples: int = 500): check_items = list(_CHECK_ITEMS.keys()) for model_name, ckpt_name in MODELS: ckpt_file = hf_hub_download('deepghs/ccip', ckpt_name, repo_type='model') model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False) ckpt_body, _ = os.path.splitext(ckpt_name) logging.info(f'Finding threshold for {ckpt_name!r} ...') threshold, accuracy = get_threshold_for_model( threshold_mean, accuracy_mean, threshold_safe, precision_safe = get_threshold_for_model( model, transforms.Compose(TEST_TRANSFORM + model.preprocess), samples=threshold_samples, ) logging.info(f'Threshold for {ckpt_file!r}: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%') logging.info(f'Threshold for {ckpt_file!r}: {threshold_mean:.4f}, accuracy: {accuracy_mean * 100.0:.2f}%') logging.info(f'Safe threshold for {ckpt_file!r}: {threshold_safe:.4f}, accuracy: {precision_safe * 100.0:.2f}%') with tempfile.TemporaryDirectory() as td: for item in check_items: click.echo(click.style(f'Try exporting {ckpt_body!r}({model_name})-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{item}.onnx') for item, safe, threshold in [ ('feat', False, threshold_mean), ('metrics', False, threshold_mean), ('metrics', True, threshold_safe), ]: click.echo(click.style(f'Try exporting {ckpt_body!r}({model_name}, ' f'safe={safe!r})-->{item} to onnx ... '), nl=False) onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{"safe_" if safe else ""}{item}.onnx') export_func = _CHECK_ITEMS[item] try: model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False) Loading