Loading zoo/pixai_tagger/export.py +7 −1 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): handler = EndpointHandler(repo_id=src_repo) meta_info = {} meta_info['repo_id'] = src_repo with TemporaryDirectory() as upload_dir: preprocessor = handler.transform preprocessor_file = os.path.join(upload_dir, 'preprocessor.json') Loading Loading @@ -82,6 +83,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white') dummy_input = handler.transform(dummy_image).unsqueeze(0).to(handler.device) meta_info['input_size'] = dummy_input.shape[-1] flops, params, macs = torch_model_profile_via_calflops(model=handler.model, input_=dummy_input) meta_info['flops'] = flops meta_info['params'] = params Loading @@ -91,8 +93,12 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): with open(new_meta_file, 'w') as f: json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False) wrapped_model, (conv_features, _, _) = get_model(handler.model, dummy_input) wrapped_model, (conv_features, conv_logits, conv_preds) = get_model(handler.model, dummy_input) conv_features = conv_features.detach().cpu() conv_logits = conv_logits.detach().cpu() conv_preds = conv_preds.detach().cpu() meta_info['num_features'] = conv_features.shape[-1] meta_info['num_classes'] = conv_preds.shape[-1] onnx_filename = os.path.join(upload_dir, 'model.onnx') with TemporaryDirectory() as td: temp_model_onnx = os.path.join(td, 'model.onnx') Loading Loading
zoo/pixai_tagger/export.py +7 −1 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): handler = EndpointHandler(repo_id=src_repo) meta_info = {} meta_info['repo_id'] = src_repo with TemporaryDirectory() as upload_dir: preprocessor = handler.transform preprocessor_file = os.path.join(upload_dir, 'preprocessor.json') Loading Loading @@ -82,6 +83,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): dummy_image = load_image(get_testfile('6125785.jpg'), mode='RGB', force_background='white') dummy_input = handler.transform(dummy_image).unsqueeze(0).to(handler.device) meta_info['input_size'] = dummy_input.shape[-1] flops, params, macs = torch_model_profile_via_calflops(model=handler.model, input_=dummy_input) meta_info['flops'] = flops meta_info['params'] = params Loading @@ -91,8 +93,12 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): with open(new_meta_file, 'w') as f: json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False) wrapped_model, (conv_features, _, _) = get_model(handler.model, dummy_input) wrapped_model, (conv_features, conv_logits, conv_preds) = get_model(handler.model, dummy_input) conv_features = conv_features.detach().cpu() conv_logits = conv_logits.detach().cpu() conv_preds = conv_preds.detach().cpu() meta_info['num_features'] = conv_features.shape[-1] meta_info['num_classes'] = conv_preds.shape[-1] onnx_filename = os.path.join(upload_dir, 'model.onnx') with TemporaryDirectory() as td: temp_model_onnx = os.path.join(td, 'model.onnx') Loading