Commit 02fb5b1f authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix convertion bug, ci skip

parent 77db215e
Loading
Loading
Loading
Loading
+11 −3
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ import pandas as pd
import torch
from ditk import logging
from hbutils.system import TemporaryDirectory
from hfutils.operate import get_hf_client
from hfutils.operate import get_hf_client, upload_directory_as_directory

from imgutils.data import load_image
from imgutils.preprocess import parse_torchvision_transforms
@@ -77,6 +77,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
        dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white')
        dummy_input = handler.transform(dummy_image).unsqueeze(0).to(handler.device)
        wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input)
        conv_features = conv_features.detach().cpu()
        onnx_filename = os.path.join(upload_dir, 'model.onnx')
        with TemporaryDirectory() as td:
            temp_model_onnx = os.path.join(td, 'model.onnx')
@@ -111,7 +112,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
            onnx.save(model, onnx_filename)

            session = onnxruntime.InferenceSession(onnx_filename)
            o_embeddings, = session.run(['embedding'], {'input': dummy_input.numpy()})
            o_embeddings, = session.run(['embedding'], {'input': dummy_input.detach().cpu().numpy()})
            emb_1 = o_embeddings / np.linalg.norm(o_embeddings, axis=-1, keepdims=True)
            emb_2 = conv_features.numpy() / np.linalg.norm(conv_features.numpy(), axis=-1, keepdims=True)
            emb_sims = (emb_1 * emb_2).sum()
@@ -119,7 +120,14 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
            assert emb_sims >= 0.98, f'Similarity of the embeddings is {emb_sims:.5f}, ONNX validation failed.'

        os.system(f'tree {upload_dir!r}')
        input()
        upload_directory_as_directory(
            repo_id=dst_repo,
            repo_type='model',
            local_directory=upload_dir,
            path_in_repo='.',
            message=f'Upload ONNX export of model {src_repo!r}'
        )

    # print(df_tags)

    # pprint(handler.index_to_tag_map)