Commit 96ba7f92 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add more

parent fafc9e12
Loading
Loading
Loading
Loading
+51 −2
Original line number Diff line number Diff line
@@ -2,22 +2,29 @@ import json
import os.path

import numpy as np
import onnx
import onnxruntime
import pandas as pd
import torch
from ditk import logging
from hbutils.system import TemporaryDirectory
from hfutils.operate import get_hf_client

from imgutils.data import load_image
from imgutils.preprocess import parse_torchvision_transforms
from zoo.pixai_tagger.tags import load_tags
from zoo.utils import onnx_optimize, get_testfile
from .min_script import EndpointHandler
from .onnx import get_model


def sync(src_repo: str, dst_repo: str):
def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
    hf_client = get_hf_client()
    if not hf_client.repo_exists(repo_id=dst_repo, repo_type='model'):
        hf_client.create_repo(repo_id=dst_repo, repo_type='model', private=True)

    handler = EndpointHandler(repo_id=src_repo)

    with TemporaryDirectory() as upload_dir:
        preprocessor = handler.transform
        preprocessor_file = os.path.join(upload_dir, 'preprocessor.json')
@@ -67,7 +74,49 @@ def sync(src_repo: str, dst_repo: str):
        ])
        df_th.to_csv(os.path.join(upload_dir, 'thresholds.csv'), index=False)

        handler.model
        dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white')
        dummy_input = handler.transform(dummy_image).unsqueeze(0)
        wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input)
        onnx_filename = os.path.join(upload_dir, 'model.onnx')
        with TemporaryDirectory() as td:
            temp_model_onnx = os.path.join(td, 'model.onnx')
            logging.info(f'Exporting temporary ONNX model to {temp_model_onnx!r} ...')
            torch.onnx.export(
                wrapped_model,
                dummy_input,
                temp_model_onnx,
                input_names=['input'],
                output_names=['embedding', 'output'],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'embedding': {0: 'batch_size'},
                    'output': {0: 'batch_size'},
                },
                opset_version=14,
                do_constant_folding=True,
                export_params=True,
                verbose=False,
                custom_opsets=None,
            )

            model = onnx.load(temp_model_onnx)
            if not no_optimize:
                logging.info('Optimizing onnx model ...')
                model = onnx_optimize(model)

            output_model_dir, _ = os.path.split(onnx_filename)
            if output_model_dir:
                os.makedirs(output_model_dir, exist_ok=True)
            logging.info(f'Complete model saving to {onnx_filename!r} ...')
            onnx.save(model, onnx_filename)

            session = onnxruntime.InferenceSession(onnx_filename)
            o_embeddings, = session.run(['embedding'], {'input': dummy_input.numpy()})
            emb_1 = o_embeddings / np.linalg.norm(o_embeddings, axis=-1, keepdims=True)
            emb_2 = conv_features.numpy() / np.linalg.norm(conv_features.numpy(), axis=-1, keepdims=True)
            emb_sims = (emb_1 * emb_2).sum()
            logging.info(f'Similarity of the embeddings is {emb_sims:.5f}.')
            assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.'

        os.system(f'tree {upload_dir!r}')
        input()
+48 −0
Original line number Diff line number Diff line
import logging

import torch
from torch import nn


class ModuleWrapper(nn.Module):
    def __init__(self, base_module: nn.Module, classifier: nn.Module):
        super().__init__()
        self.base_module = base_module
        self.classifier = classifier

        self._output_features = None
        self._register_hook()

    def _register_hook(self):
        def hook_fn(module, input_tensor, output_tensor):
            assert isinstance(input_tensor, tuple) and len(input_tensor) == 1
            input_tensor = input_tensor[0]
            self._output_features = input_tensor

        self.classifier.register_forward_hook(hook_fn)

    def forward(self, x: torch.Tensor):
        preds = self.base_module(x)

        if self._output_features is None:
            raise RuntimeError("Target module did not receive any input during forward pass")
        features, self._output_features = self._output_features, None
        assert all([x == 1 for x in features.shape[2:]]), f'Invalid feature shape: {features.shape!r}'
        features = torch.flatten(features, start_dim=1)

        return features, preds


def get_model(model: nn.Module, dummy_input: torch.Tensor):
    assert isinstance(model, nn.Sequential)
    head = model[-1]
    wrapped_model = ModuleWrapper(model, head)

    logging.info(f'Input size: {dummy_input.shape!r}')
    with torch.no_grad():
        dummy_embedding, dummy_preds = wrapped_model(dummy_input)
    logging.info(f'Embedding size: {dummy_embedding.shape!r}')
    logging.info(f'Preds size: {dummy_preds.shape!r}')

    return wrapped_model, (dummy_embedding, dummy_preds)
    # print(model[-1])