Commit 4554e36d authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add preprocessor

parent d43dcae5
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ It also handles token-based authentication for accessing private Hugging Face re
import json
import os
from threading import Lock
from typing import Tuple, Optional, List, Dict
from typing import Tuple, Optional, List, Dict, Callable

import numpy as np
from PIL import Image
@@ -88,6 +88,9 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
    return data.astype(np.float32)


ImagePreprocessFunc = Callable[[Image.Image], Image.Image]


class ClassifyModel:
    """
    A class for managing and using classification models.
@@ -114,7 +117,8 @@ class ClassifyModel:
        >>> print(f"Predicted class: {prediction}, Score: {score}")
    """

    def __init__(self, repo_id: str, hf_token: Optional[str] = None):
    def __init__(self, repo_id: str, fn_preprocess: Optional[ImagePreprocessFunc] = None,
                 hf_token: Optional[str] = None):
        """
        Initialize the ClassifyModel instance.

@@ -124,6 +128,7 @@ class ClassifyModel:
        :type hf_token: Optional[str], optional
        """
        self.repo_id = repo_id
        self._fn_preprocess = fn_preprocess
        self._model_names = None
        self._models = {}
        self._labels = {}
@@ -254,6 +259,9 @@ class ClassifyModel:
            raise RuntimeError(f'Model {model_name!r} required {[batch, channels, height, width]!r}, '
                               f'channels not 3.')  # pragma: no cover

        if self._fn_preprocess:
            image = self._fn_preprocess(image)

        if isinstance(height, int) and isinstance(width, int):
            input_ = _img_encode(image, size=(width, height))[None, ...]
        else: