Commit 632f2066 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix problem of onnx export

parent 140f4648
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -35,11 +35,11 @@ def load_image(image: ImageTyping, mode=None, force_background: Optional[str] =
    return image


def load_images(images: MultiImagesTyping, mode=None) -> List[Image.Image]:
def load_images(images: MultiImagesTyping, mode=None, force_background: Optional[str] = 'white') -> List[Image.Image]:
    if not isinstance(images, (list, tuple)):
        images = [images]

    return [load_image(item, mode) for item in images]
    return [load_image(item, mode, force_background) for item in images]


def add_background_for_rgba(image: ImageTyping, background: str = 'white'):
+2 −2
Original line number Diff line number Diff line
@@ -148,7 +148,7 @@ def onnx_check(model: str, check_item: Optional[str] = None, verbose: bool = Fal


MODELS = [
    # ('caformer', 'ccip-caformer-2_fp32.ckpt'),
    ('caformer', 'ccip-caformer-2_fp32.ckpt'),
    ('caformer', 'ccip-caformer-4_fp32.ckpt'),
]

@@ -183,7 +183,7 @@ def export(output_dir: str, verbose: bool = False, threshold_samples: int = 500)
                onnx_filename = os.path.join(output_dir or td, f'{ckpt_body}_{item}.onnx')
                export_func = _CHECK_ITEMS[item]
                try:
                    model = CCIP(model_name)  # necessary
                    model, preprocess = _get_model_from_ckpt(model_name, ckpt_file, device='cpu', fp16=False)
                    if verbose:
                        export_func(model, threshold, onnx_filename, verbose=verbose)
                    else:
+1 −1
Original line number Diff line number Diff line
@@ -60,7 +60,7 @@ class LogitToConfidence(nn.Module):
        self.threshold: torch.Tensor

    def forward(self, x):
        ex = (x - self.threshold)
        ex = x - self.threshold
        return torch.exp(ex) / (torch.exp(ex) + 1.0)