Commit 03aabb9b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): refactor model

parent 10f267d1
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
hbutils>=0.8.2
hbutils>=0.9.0
pillow
numpy
scikit-learn
+54 −0
Original line number Diff line number Diff line
import os.path
import re
import tempfile
from functools import partial
from typing import Optional

import click
from hbutils.testing import disable_output

from .model import CCIP
from .onnx import export_full_model_to_onnx, export_feat_model_to_onnx, export_metrics_model_to_onnx
from ..utils import GLOBAL_CONTEXT_SETTINGS
from ..utils import print_version as _origin_print_version

@@ -14,3 +21,50 @@ print_version = partial(_origin_print_version, 'zoo.ccip')
              help="Utils with pixiv resources.")
def cli():
    pass  # pragma: no cover


# Maps each value accepted by the ``--check`` option of ``onnx_check`` to the
# export function used for that model part. The key order is also the order
# in which parts are checked when ``--check`` is not given.
_CHECK_ITEMS = {
    'full': export_full_model_to_onnx,
    'feat': export_feat_model_to_onnx,
    'metrics': export_metrics_model_to_onnx,
}


@cli.command('onnx_check', help='Check onnx export is okay or not')
@click.option('--model', '-m', 'model', type=str, required=True,
              help='Model to be checked. ', show_default=True)
@click.option('--check', '-c', 'check_item', type=click.Choice(list(_CHECK_ITEMS.keys())), default=None,
              help='Model part to be checked. All parts will be checked when not given', show_default=True)
@click.option('--verbose', '-V', 'verbose', is_flag=True, type=bool, default=False,
              help='Show verbose information.', show_default=True)
@click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), default=None,
              help='Output directory of all models.', show_default=True)
def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = False,
               output_dir: Optional[str] = None):
    """
    Try exporting one or all parts of a CCIP model to ONNX and report the result.

    For every selected part, a progress line is echoed, the matching export
    function from ``_CHECK_ITEMS`` is invoked, and either ``OK`` or ``FAILED``
    is printed. A failure is re-raised after being reported.

    :param model: Name passed to ``CCIP(...)`` to build the model.
    :param check_item: Key of ``_CHECK_ITEMS`` to check; all keys are checked when empty.
    :param verbose: Forwarded to the export function; when false, output is
        suppressed with ``disable_output()`` during the export.
    :param output_dir: Directory for the exported files; a temporary directory
        is used when not given.
    """
    # ``model`` is the model *name* here; a fresh CCIP instance is built for
    # every item inside the loop, so no instance is created up front.
    model_name = model
    check_items = [check_item] if check_item else list(_CHECK_ITEMS.keys())

    with tempfile.TemporaryDirectory() as td:
        for item in check_items:
            click.echo(click.style(f'Try exporting {model_name}-->{item} to onnx ... '), nl=False)
            # Non-word characters (e.g. '/' in the model name) are collapsed to '-'.
            onnx_filename = os.path.join(output_dir or td, re.sub(r'\W+', '-', f'{model_name}_{item}') + '.onnx')
            export_func = _CHECK_ITEMS[item]
            try:
                # A new model per item is necessary (kept from the original code).
                ccip_model = CCIP(model_name)
                if verbose:
                    export_func(ccip_model, onnx_filename, verbose=verbose)
                else:
                    with disable_output():
                        export_func(ccip_model, onnx_filename, verbose=verbose)
            except BaseException:
                # Report and re-raise everything, including KeyboardInterrupt.
                click.echo(click.style('FAILED', fg='red'), nl=True)
                raise
            else:
                click.echo(click.style('OK', fg='green'), nl=True)


# Run the click command group when this module is executed as a script.
if __name__ == '__main__':
    cli()
+24 −23
Original line number Diff line number Diff line
@@ -7,16 +7,15 @@ from zoo.utils import get_testfile
from .backbone import get_backbone


class DiffMethod(nn.Module):
class CCIPBatchMetrics(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)
        self.sim = nn.CosineSimilarity(dim=-1)
        self.fc = nn.Linear(1, 2)

    def forward(self, x):
        """
        Bx1 --> Bx2
        """
        x = self.fc(x)
    def forward(self, x):  # x: BxN
        x = self.sim(x, x.unsqueeze(1))
        x = self.fc(x.unsqueeze(-1))
        return x


@@ -34,30 +33,32 @@ class CCIPFeature(torch.nn.Module):
class CCIP(torch.nn.Module):
    def __init__(self, name: str = "clip/ViT-B/32"):
        torch.nn.Module.__init__(self)
        self.backbone = CCIPFeature(name)
        self.diff = DiffMethod()
        self.cos_sim = torch.nn.CosineSimilarity(dim=-1)
        self.feature = CCIPFeature(name)
        self.metrics = CCIPBatchMetrics()

    @property
    def preprocess(self):
        return self.backbone.preprocess
        return self.feature.preprocess

    def forward(self, x, y):
        x = self.backbone(x)
        y = self.backbone(y)
        dis = self.cos_sim(x, y)
        return self.diff(dis)
    def forward(self, x):
        # x: BxCxHxW
        x = self.feature(x)  # BxF
        x = self.metrics(x)  # BxBx2
        return x


if __name__ == '__main__':
    image1 = Image.open(get_testfile('6124220.jpg'))
    image2 = Image.open(get_testfile('6125785.jpg'))
    image_files = [
        get_testfile('6124220.jpg'),
        get_testfile('6125785.jpg'),
        get_testfile('6125901.jpg'),
    ]

    model = CCIP()
    d1 = model.preprocess(image1).unsqueeze(0)
    d2 = model.preprocess(image2).unsqueeze(0)

    print(d1.shape, d1.dtype)
    print(d2.shape, d2.dtype)
    data = torch.stack([
        model.preprocess(Image.open(img))
        for img in image_files
    ])
    print(data.dtype, data.shape)

    print(F.softmax(model.forward(d1, d2), dim=-1))
    print(F.softmax(model.forward(data), dim=-1))

zoo/ccip/onnx.py

0 → 100644
+110 −0
Original line number Diff line number Diff line
import os
from tempfile import TemporaryDirectory

import onnx
import torch
from PIL import Image
from torch import nn

from .model import CCIP
from ..utils import get_testfile, onnx_optimize


class ModelWithSoftMax(nn.Module):
    """
    Wrapper module that applies a softmax over the last dimension of the
    wrapped model's output.
    """

    def __init__(self, model):
        """
        :param model: Module whose output should be passed through softmax.
        """
        super().__init__()
        self.model = model

    def forward(self, x):
        """
        Run the wrapped model on ``x`` and normalise the result along ``dim=-1``.
        """
        logits = self.model(x)
        return torch.softmax(logits, dim=-1)


def get_batch_images(preprocess) -> torch.Tensor:
    """
    Build a sample batch from three bundled test images.

    :param preprocess: Callable applied to each opened image; its results are
        combined with ``torch.stack``.
    :return: The stacked batch, one entry per test image.
    """
    sample_names = ('6124220.jpg', '6125785.jpg', '6125901.jpg')
    # Resolve every path first, then open and preprocess the images in order.
    sample_paths = [get_testfile(name) for name in sample_names]

    processed = []
    for path in sample_paths:
        processed.append(preprocess(Image.open(path)))
    return torch.stack(processed)


def _onnx_export(model, example_input, onnx_filename, opset_version: int = 14, verbose: bool = True,
                 no_optimize: bool = False, dynamic_axes=None):
    """
    Export ``model`` to ONNX and save the result to ``onnx_filename``.

    The model is cast to float and moved, together with ``example_input``, to
    CUDA when available (CPU otherwise). It is exported to a temporary file,
    reloaded with ``onnx.load``, passed through ``onnx_optimize`` unless
    disabled, and saved to the requested path.

    :param model: Torch module to export.
    :param example_input: Example tensor handed to ``torch.onnx.export``.
    :param onnx_filename: Destination path; its parent directory is created when needed.
    :param opset_version: ONNX opset version passed to the exporter.
    :param verbose: Passed to ``torch.onnx.export``.
    :param no_optimize: Skip the ``onnx_optimize`` step when true.
    :param dynamic_axes: Dynamic axes mapping for the exporter; an empty dict is used when not given.
    """
    model = model.float()
    use_cuda = torch.cuda.is_available()
    example_input = example_input.cuda() if use_cuda else example_input.cpu()
    model = model.cuda() if use_cuda else model.cpu()

    axes = dynamic_axes if dynamic_axes else {}
    with torch.no_grad(), TemporaryDirectory() as workdir:
        # Export to a scratch file first, so optimisation runs on the reloaded model.
        raw_onnx_path = os.path.join(workdir, 'model.onnx')
        torch.onnx.export(
            model,
            example_input,
            raw_onnx_path,
            verbose=verbose,
            input_names=["input"],
            output_names=["output"],
            opset_version=opset_version,
            dynamic_axes=axes,
        )

        onnx_model = onnx.load(raw_onnx_path)
        if not no_optimize:
            onnx_model = onnx_optimize(onnx_model)

        # A bare filename has no directory component, so there is nothing to create.
        target_dir = os.path.dirname(onnx_filename)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        onnx.save(onnx_model, onnx_filename)


def export_full_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 14, verbose: bool = True,
                              no_optimize: bool = False):
    """
    Export the complete CCIP model, followed by a softmax, to ONNX.

    The example input is the sample image batch built with the model's own
    ``preprocess``. Input axis 0 and output axes 0 and 1 are declared dynamic
    under the name ``batch``.

    :param model: CCIP model to export.
    :param onnx_filename: Destination path of the ONNX file.
    :param opset_version: ONNX opset version.
    :param verbose: Passed through to the exporter.
    :param no_optimize: Skip the optimisation step when true.
    :return: Result of ``_onnx_export``.
    """
    sample_batch = get_batch_images(model.preprocess)
    wrapped = ModelWithSoftMax(model)
    axes = {
        "input": {0: "batch"},
        "output": {0: "batch", 1: "batch"},
    }
    return _onnx_export(
        wrapped, sample_batch, onnx_filename,
        opset_version=opset_version,
        verbose=verbose,
        no_optimize=no_optimize,
        dynamic_axes=axes,
    )


def export_feat_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 14, verbose: bool = True,
                              no_optimize: bool = False):
    """
    Export only the ``feature`` sub-module of a CCIP model to ONNX.

    The example input is the sample image batch built with the model's own
    ``preprocess``. Axis 0 of both input and output is declared dynamic under
    the name ``batch``.

    :param model: CCIP model whose ``feature`` part is exported.
    :param onnx_filename: Destination path of the ONNX file.
    :param opset_version: ONNX opset version.
    :param verbose: Passed through to the exporter.
    :param no_optimize: Skip the optimisation step when true.
    :return: Result of ``_onnx_export``.
    """
    sample_batch = get_batch_images(model.preprocess)
    axes = {
        "input": {0: "batch"},
        "output": {0: "batch"},
    }
    return _onnx_export(
        model.feature, sample_batch, onnx_filename,
        opset_version=opset_version,
        verbose=verbose,
        no_optimize=no_optimize,
        dynamic_axes=axes,
    )


def export_metrics_model_to_onnx(model: CCIP, onnx_filename, opset_version: int = 14, verbose: bool = True,
                                 no_optimize: bool = False):
    """
    Export only the ``metrics`` sub-module of a CCIP model, followed by a
    softmax, to ONNX.

    The example input is produced by running the sample image batch through
    ``model.feature`` without gradients. Input axis 0 and output axes 0 and 1
    are declared dynamic under the name ``batch``.

    :param model: CCIP model whose ``metrics`` part is exported.
    :param onnx_filename: Destination path of the ONNX file.
    :param opset_version: ONNX opset version.
    :param verbose: Passed through to the exporter.
    :param no_optimize: Skip the optimisation step when true.
    :return: Result of ``_onnx_export``.
    """
    images = get_batch_images(model.preprocess)
    with torch.no_grad():
        # The metrics part consumes features, not raw images.
        features = model.feature(images)

    wrapped = ModelWithSoftMax(model.metrics)
    axes = {
        "input": {0: "batch"},
        "output": {0: "batch", 1: "batch"},
    }
    return _onnx_export(
        wrapped, features, onnx_filename,
        opset_version=opset_version,
        verbose=verbose,
        no_optimize=no_optimize,
        dynamic_axes=axes,
    )
+8 −14
Original line number Diff line number Diff line
@@ -134,14 +134,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

            ix = torch.arange(0, char_ids.shape[0])
            mask = ix >= ix.reshape(-1, 1)  # BxB, remove duplicated
            m_labels = (char_ids == char_ids.reshape(-1, 1))  # BxB
            logits = model(inputs)  # BxBx2
            outputs = logits[mask]  # Nx2
            labels = (char_ids == char_ids.reshape(-1, 1))[mask]  # N
            labels = labels.type(torch.long).to(accelerator.device)  # N

            features = model.backbone(inputs)  # BxF
            m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1))  # BxB
            sims = m_sims[mask].to(accelerator.device)
            labels = m_labels[mask].type(torch.long).to(accelerator.device)

            outputs = model.diff(sims.reshape(-1, 1))
            preds = torch.argmax(outputs, dim=1)
            train_correct += (preds == labels).sum().item()
            train_fp += (preds[labels == 0] == 1).sum().item()
@@ -182,14 +179,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

                    ix = torch.arange(0, char_ids.shape[0])
                    mask = ix >= ix.reshape(-1, 1)  # BxB, remove duplicated
                    m_labels = (char_ids == char_ids.reshape(-1, 1))  # BxB

                    features = model.backbone(inputs)  # BxF
                    m_sims = torch.nn.CosineSimilarity(dim=-1)(features, features.unsqueeze(1))  # BxB
                    sims = m_sims[mask].to(accelerator.device)
                    labels = m_labels[mask].type(torch.long).to(accelerator.device)
                    logits = model(inputs)  # BxBx2
                    outputs = logits[mask]  # Nx2
                    labels = (char_ids == char_ids.reshape(-1, 1))[mask]  # N
                    labels = labels.type(torch.long).to(accelerator.device)  # N

                    outputs = model.diff(sims.reshape(-1, 1))
                    preds = torch.argmax(outputs, dim=1)
                    test_correct += (preds == labels).sum().item()
                    test_fp += (preds[labels == 0] == 1).sum().item()