Commit 9f7f532a authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update cls test

parent a74b573c
Loading
Loading
Loading
Loading
+28 −3
Original line number Diff line number Diff line
@@ -21,8 +21,10 @@ from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_fs_path, hf_normpath
from huggingface_hub import hf_hub_download, HfFileSystem
from huggingface_hub.errors import EntryNotFoundError

from ..data import rgb_encode, ImageTyping, load_image
from ..preprocess import create_pillow_transforms
from ..utils import open_onnx_model, ts_lru_cache

try:
@@ -133,6 +135,7 @@ class ClassifyModel:
        self._model_names = None
        self._models = {}
        self._labels = {}
        self._preprocesses = {}
        self._hf_token = hf_token
        self._global_lock = Lock()
        self._model_lock = Lock()
@@ -241,6 +244,24 @@ class ClassifyModel:

        return self._labels[model_name]

    def _open_preprocess(self, model_name: str):
        with self._model_lock:
            if model_name not in self._preprocesses:
                try:
                    pfile = hf_hub_download(
                        self.repo_id,
                        f'{model_name}/preprocess.json',
                        token=self._get_hf_token(),
                    )
                except EntryNotFoundError:
                    self._preprocesses[model_name] = None
                else:
                    with open(pfile, 'r') as f:
                        stages_info = json.load(f)['stages']
                        self._preprocesses[model_name] = create_pillow_transforms(stages_info)

            return self._preprocesses[model_name]

    def _raw_predict(self, image: ImageTyping, model_name: str):
        """
        Generate raw model predictions for an input image.
@@ -271,6 +292,10 @@ class ClassifyModel:
        if self._fn_preprocess:
            image = self._fn_preprocess(image)

        preprocess = self._open_preprocess(model_name=model_name)
        if preprocess:
            input_ = preprocess(image)[None, ...]
        else:
            if isinstance(height, int) and isinstance(width, int):
                input_ = _img_encode(image, size=(width, height))[None, ...]
            else: