Commit cc738c10 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update onnx export

parent e88c69d8
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -18,9 +18,10 @@ jobs:
        model-name:
          #          - 'lpips'
          #          - 'monochrome'
          - 'person_detect'
          #          - 'person_detect'
          #          - 'face_detect'
          #          - 'manbits_detect'
          - 'ccip'

    steps:
      - name: Checkout code
+133 −4
Original line number Diff line number Diff line
import glob
import os.path
import random
import re
import tempfile
from functools import partial
from typing import Optional
from typing import Optional, Tuple

import click
import torch
from ditk import logging
from hbutils.testing import disable_output
from huggingface_hub import hf_hub_download
from sklearn import svm
from sklearn.metrics import accuracy_score
from torchvision import transforms
from tqdm.auto import tqdm

from imgutils.data import load_image
from .dataset import TEST_TRANSFORM
from .demo import _get_model_from_ckpt
from .model import CCIP
from .onnx import export_full_model_to_onnx, export_feat_model_to_onnx, export_metrics_model_to_onnx
from ..utils import GLOBAL_CONTEXT_SETTINGS
@@ -30,6 +42,62 @@ _CHECK_ITEMS = {
}


def _sample_analysis(poss, negs, svm_samples: int = 10000):
    poss_cnt, negs_cnt = poss.shape[0], negs.shape[0]
    total = poss_cnt + negs_cnt
    if total > svm_samples:
        s_poss = poss[random.sample(range(poss_cnt), k=int(round(poss_cnt * svm_samples / total)))]
        s_negs = negs[random.sample(range(negs_cnt), k=int(round(negs_cnt * svm_samples / total)))]
    else:
        s_poss, s_negs = poss, negs

    s_poss, s_negs = s_poss.cpu(), s_negs.cpu()
    features = torch.cat([s_poss, s_negs]).detach().numpy()
    labels = torch.cat([torch.ones_like(s_poss), -torch.ones_like(s_negs)]).detach().numpy()

    model = svm.SVC(kernel='linear')  # 线性核
    model.fit(features.reshape(-1, 1), labels)
    predictions = model.predict(features.reshape(-1, 1))

    coef = model.coef_.reshape(-1)[0].tolist()
    inter = model.intercept_.reshape(-1)[0].tolist()
    threshold = -inter / coef

    return poss.mean().item(), poss.std().item(), \
        negs.mean().item(), negs.std().item(), \
        threshold, accuracy_score(labels, predictions)


@torch.no_grad()
def get_threshold_for_model(model: CCIP, preprocess, samples: int = 200) -> Tuple[float, float]:
    def _get_sim(x, y):
        x, y = load_image(x, mode='RGB'), load_image(y, mode='RGB')
        input_ = torch.stack([preprocess(x), preprocess(y)])
        return model(input_)[0][1]

    dataset_dir = 'test/testfile/dataset/images_xtiny_v0/'
    all_images = glob.glob(os.path.join(dataset_dir, '*', '*', '*.jpg'))
    all_chs = sorted(set([os.path.dirname(img) for img in all_images]))

    not_same_samples = []
    for _ in tqdm(range(samples)):
        x_ch, y_ch = random.sample(all_chs, k=2)
        x_img = random.choice(glob.glob(os.path.join(x_ch, '*.jpg')))
        y_img = random.choice(glob.glob(os.path.join(y_ch, '*.jpg')))
        not_same_samples.append(_get_sim(x_img, y_img))
    not_same_samples = torch.as_tensor(not_same_samples)

    same_samples = []
    for _ in tqdm(range(samples)):
        ch = random.choice(all_chs)
        x_img, y_img = random.sample(glob.glob(os.path.join(ch, '*.jpg')), k=2)
        same_samples.append(_get_sim(x_img, y_img))
    same_samples = torch.as_tensor(same_samples)

    _, _, _, _, threshold, accuracy = _sample_analysis(same_samples, not_same_samples, svm_samples=samples)
    return threshold, accuracy


@cli.command('onnx_check', help='Check onnx export is okay or not')
@click.option('--model', '-m', 'model', type=str, required=True,
              help='Model to be checked. ', show_default=True)
@@ -39,14 +107,27 @@ _CHECK_ITEMS = {
              help='Show verbose information.', show_default=True)
@click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), default=None,
              help='Output directory of all models.', show_default=True)
@click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500,
              help='Batch of samples to find threshold.', show_default=True)
def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = False,
               output_dir: Optional[str] = None):
               output_dir: Optional[str] = None, threshold_samples: int = 500):
    logging.try_init_root(logging.INFO)

    model, model_name = CCIP(model), model
    model.eval()
    if not check_item:
        check_items = list(_CHECK_ITEMS.keys())
    else:
        check_items = [check_item]

    logging.info('Finding threshold ...')
    threshold, accuracy = get_threshold_for_model(
        model,
        transforms.Compose(TEST_TRANSFORM + model.preprocess),
        samples=threshold_samples,
    )
    logging.info(f'Threshold: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%')

    with tempfile.TemporaryDirectory() as td:
        for item in check_items:
            click.echo(click.style(f'Try exporting {model_name}-->{item} to onnx ... '), nl=False)
@@ -55,10 +136,58 @@ def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = Fal
            try:
                model = CCIP(model_name)  # necessary
                if verbose:
                    export_func(model, onnx_filename, verbose=verbose)
                    export_func(model, threshold, onnx_filename, verbose=verbose)
                else:
                    with disable_output():
                        export_func(model, threshold, onnx_filename, verbose=verbose)
            except:
                click.echo(click.style('FAILED', fg='red'), nl=True)
                raise
            else:
                click.echo(click.style('OK', fg='green'), nl=True)


MODELS = [
    ('caformer', 'ccip-caformer-2_fp32.ckpt'),
]


@cli.command('export', help='Export all models as onnx.',
             context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), required=True,
              help='Output directory of all models.', show_default=True)
@click.option('--verbose', '-V', 'verbose', is_flag=True, type=bool, default=False,
              help='Show verbose information.', show_default=True)
@click.option('--threshold_samples', '-T', 'threshold_samples', type=int, default=500,
              help='Batch of samples to find threshold.', show_default=True)
def export(output_dir: str, verbose: bool = False, threshold_samples: int = 500):
    check_items = list(_CHECK_ITEMS.keys())

    for model_name, ckpt_name in MODELS:
        ckpt_file = hf_hub_download('deepghs/ccip', ckpt_name, repo_type='model')
        model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False)
        ckpt_body, _ = os.path.splitext(ckpt_name)

        logging.info(f'Finding threshold for {ckpt_name!r} ...')
        threshold, accuracy = get_threshold_for_model(
            model,
            transforms.Compose(TEST_TRANSFORM + model.preprocess),
            samples=threshold_samples,
        )
        logging.info(f'Threshold for {ckpt_file!r}: {threshold:.4f}, accuracy: {accuracy * 100.0:.2f}%')

        with tempfile.TemporaryDirectory() as td:
            for item in check_items:
                click.echo(click.style(f'Try exporting {ckpt_body!r}({model_name})-->{item} to onnx ... '), nl=False)
                onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{item}.onnx')
                export_func = _CHECK_ITEMS[item]
                try:
                    model = CCIP(model_name)  # necessary
                    if verbose:
                        export_func(model, threshold, onnx_filename, verbose=verbose)
                    else:
                        with disable_output():
                        export_func(model, onnx_filename, verbose=verbose)
                            export_func(model, threshold, onnx_filename, verbose=verbose)
                except:
                    click.echo(click.style('FAILED', fg='red'), nl=True)
                    raise
+24 −13
Original line number Diff line number Diff line
@@ -13,24 +13,35 @@ def _load_remote_ckpt(remote_ckpt):
    return hf_hub_download('deepghs/ccip', remote_ckpt, repo_type='model')


class Infer:
    def __init__(self, args, device=None):
        self.args = args
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
def _get_model_from_ckpt(model_name, ckpt, device, fp16: bool):
    model = CCIP(model_name).to(device)
    model.eval()
    if fp16:
        model = model.half()

        self.model = CCIP(args.model_name).to(device)
        self.model.eval()
        if self.args.fp16:
            self.model = self.model.half()

        state = torch.load(args.ckpt or _load_remote_ckpt(args.remote_ckpt), map_location='cpu')
    state = torch.load(ckpt, map_location='cpu')
    try:
            self.model.load_state_dict(state)
        model.load_state_dict(state)
    except:
        len_p = len('module._orig_mod.')
            self.model.load_state_dict({k[len_p:]: v for k, v in state.items()})
        model.load_state_dict({k[len_p:]: v for k, v in state.items()})

    preprocess = transforms.Compose(TEST_TRANSFORM + model.preprocess)

    return model, preprocess


class Infer:
    def __init__(self, args, device=None):
        self.args = args
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        self.img_transform = transforms.Compose(TEST_TRANSFORM + self.model.preprocess)
        self.model, self.img_transform = _get_model_from_ckpt(
            model_name=args.model_name,
            ckpt=args.ckpt or _load_remote_ckpt(args.remote_ckpt),
            device=self.device,
            fp16=args.fp16
        )

    def load_img(self, path):
        image = load_image(path, mode='RGB')
+13 −2
Original line number Diff line number Diff line
@@ -13,13 +13,13 @@ class CCIPBatchMetrics(nn.Module):
        # self.sim = nn.CosineSimilarity(dim=-1)

    def forward(self, image_features):  # x: BxN

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ image_features.t()
        if self.training:
            logits_per_image = logits_per_image - torch.diag_embed(torch.diag(logits_per_image))

        return logits_per_image
@@ -53,6 +53,17 @@ class CCIP(nn.Module):
        return x


class LogitToConfidence(nn.Module):
    def __init__(self, threshold):
        nn.Module.__init__(self)
        self.register_buffer('threshold', torch.tensor(threshold))
        self.threshold: torch.Tensor

    def forward(self, x):
        ex = (x - self.threshold)
        return torch.exp(ex) / (torch.exp(ex) + 1.0)


if __name__ == '__main__':
    # image_files = [
    #     get_testfile('6124220.jpg'),
+17 −14
Original line number Diff line number Diff line
@@ -5,20 +5,21 @@ import onnx
import torch
from PIL import Image
from torch import nn
from torchvision import transforms

from .model import CCIP
from .dataset import TEST_TRANSFORM
from .model import CCIP, LogitToConfidence
from ..utils import get_testfile, onnx_optimize


class ModelWithSoftMax(nn.Module):
    def __init__(self, model):
class ModelWithConfidence(nn.Module):
    def __init__(self, model, threshold):
        nn.Module.__init__(self)
        self.model = model
        self.logit_to_conf = LogitToConfidence(threshold)

    def forward(self, x):
        x = self.model(x)
        x = torch.softmax(x, dim=-1)
        return x
        return self.logit_to_conf(self.model(x))


def get_batch_images(preprocess) -> torch.Tensor:
@@ -28,6 +29,7 @@ def get_batch_images(preprocess) -> torch.Tensor:
        get_testfile('6125901.jpg'),
    ]

    preprocess = transforms.Compose(TEST_TRANSFORM + preprocess)
    return torch.stack([
        preprocess(Image.open(img))
        for img in image_files
@@ -68,11 +70,11 @@ def _onnx_export(model, example_input, onnx_filename, opset_version: int = 14, v
        onnx.save(model, onnx_filename)


def export_full_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 14, verbose: bool = True,
                              no_optimize: bool = False):
def export_full_model_to_onnx(model: CCIP, threshold: float, onnx_filename, opset_version: int = 14,
                              verbose: bool = True, no_optimize: bool = False):
    example_input = get_batch_images(model.preprocess)
    return _onnx_export(
        ModelWithSoftMax(model), example_input,
        ModelWithConfidence(model, threshold), example_input,
        onnx_filename, opset_version, verbose, no_optimize,
        dynamic_axes={
            "input": {0: "batch"},
@@ -81,8 +83,9 @@ def export_full_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 1
    )


def export_feat_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 14, verbose: bool = True,
                              no_optimize: bool = False):
def export_feat_model_to_onnx(model: CCIP, threshold: float, onnx_filename, opset_version: int = 14,
                              verbose: bool = True, no_optimize: bool = False):
    _ = threshold
    example_input = get_batch_images(model.preprocess)
    return _onnx_export(
        model.feature, example_input,
@@ -94,14 +97,14 @@ def export_feat_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 1
    )


def export_metrics_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 14, verbose: bool = True,
                                 no_optimize: bool = False):
def export_metrics_model_to_onnx(model: CCIP, threshold: float, onnx_filename, opset_version: int = 14,
                                 verbose: bool = True, no_optimize: bool = False):
    origin = get_batch_images(model.preprocess)
    with torch.no_grad():
        example_input = model.feature(origin)

    return _onnx_export(
        ModelWithSoftMax(model.metrics), example_input,
        ModelWithConfidence(model.metrics, threshold), example_input,
        onnx_filename, opset_version, verbose, no_optimize,
        dynamic_axes={
            "input": {0: "batch"},