Unverified Commit 6a3498d6 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #77 from deepghs/dev/cls

dev(narugo): add better generic clasisification
parents c4d10627 3d33bbf8
Loading
Loading
Loading
Loading
+10 −1
Original line number Diff line number Diff line
@@ -81,6 +81,15 @@ class ClassifyModel:

    def _raw_predict(self, image: ImageTyping, model_name: str):
        image = load_image(image, force_background='white', mode='RGB')
        model = self._open_model(model_name)
        batch, channels, height, width = model.get_inputs()[0].shape
        if channels != 3:
            raise RuntimeError(f'Model {model_name!r} required {[batch, channels, height, width]!r}, '
                               f'channels not 3.')  # pragma: no cover

        if isinstance(height, int) and isinstance(width, int):
            input_ = _img_encode(image, size=(width, height))[None, ...]
        else:
            input_ = _img_encode(image)[None, ...]
        output, = self._open_model(model_name).run(['output'], {'input': input_})
        return output