Commit 150b42f8 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add some more meta information

parent 18a9bad5
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):

    handler = EndpointHandler(repo_id=src_repo)
    meta_info = {}
    meta_info['repo_id'] = src_repo
    with TemporaryDirectory() as upload_dir:
        preprocessor = handler.transform
        preprocessor_file = os.path.join(upload_dir, 'preprocessor.json')
@@ -82,6 +83,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):

        dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white')
        dummy_input = handler.transform(dummy_image).unsqueeze(0).to(handler.device)
        meta_info['input_size'] = dummy_input.shape[-1]
        flops, params, macs = torch_model_profile_via_calflops(model=handler.model, input_=dummy_input)
        meta_info['flops'] = flops
        meta_info['params'] = params
@@ -91,8 +93,12 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
        with open(new_meta_file, 'w') as f:
            json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False)

        wrapped_model, (conv_features, _, _) = get_model(handler.model, dummy_input)
        wrapped_model, (conv_features, conv_logits, conv_preds) = get_model(handler.model, dummy_input)
        conv_features = conv_features.detach().cpu()
        conv_logits = conv_logits.detach().cpu()
        conv_preds = conv_preds.detach().cpu()
        meta_info['num_features'] = conv_features.shape[-1]
        meta_info['num_classes'] = conv_preds.shape[-1]
        onnx_filename = os.path.join(upload_dir, 'model.onnx')
        with TemporaryDirectory() as td:
            temp_model_onnx = os.path.join(td, 'model.onnx')