Commit fbd90d6f authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix bug in cli

parent b14156f8
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -53,7 +53,8 @@ def export(output_dir: str):
        model = _KNOWN_MODELS[model_name]().float()
        ckpt_file = hf_hub_download('deepghs/imgutils-models', f'monochrome/{ckpt}')
        model.load_state_dict(torch.load(ckpt_file, map_location='cpu'))
        output_file = os.path.join(output_dir, os.path.basename(ckpt))
        filebody, _ = os.path.splitext(ckpt)
        output_file = os.path.join(output_dir, f'{filebody}.onnx')
        export_model_to_onnx(model, output_file, feature_bins=feature_bins)