Commit f40f6991 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add model torch code inside

parent e2142d0c
Loading
Loading
Loading
Loading
+7 −3
Original line number Diff line number Diff line
@@ -981,9 +981,13 @@ if __name__ == '__main__':

    print(model)

    # # Define example input – a dummy image tensor of the expected input shape (1, 3, 512, 512)
    # dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)
    #
    # Define example input – a dummy image tensor of the expected input shape (1, 3, 512, 512)
    dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)
    with torch.no_grad():
        dummy_init_logits, dummy_refined_logits = model(dummy_input)
    print(dummy_init_logits.shape, dummy_init_logits.dtype)
    print(dummy_refined_logits.shape, dummy_refined_logits.dtype)

    # # Export to ONNX
    # onnx_path = "camie_tagger_initial_v15.onnx"
    # torch.onnx.export(
+7 −2
Original line number Diff line number Diff line
@@ -188,8 +188,13 @@ if __name__ == '__main__':
    # (Optional) Cast to float32 if weights were in half precision
    # model = model.float()

    # # --- Export to ONNX ---
    # dummy_input = torch.randn(1, 3, 512, 512, requires_grad=False)  # dummy batch of 1 image (3x512x512)
    # --- Export to ONNX ---
    dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)
    with torch.no_grad():
        dummy_init_logits, dummy_refined_logits = model(dummy_input)
    print(dummy_init_logits.shape, dummy_init_logits.dtype)
    print(dummy_refined_logits.shape, dummy_refined_logits.dtype)

    # output_onnx_file = "camie_refined_no_flash_v15.onnx"
    # torch.onnx.export(
    #     model, dummy_input, output_onnx_file,