Commit c0963685 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add hf token

parent 669b6e0e
Loading
Loading
Loading
Loading
+17 −8
Original line number Diff line number Diff line
@@ -57,6 +57,7 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),

    :raises TypeError: If the input image is not a PIL Image object.
    """
    # noinspection PyUnresolvedReferences
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')

@@ -110,12 +111,12 @@ class ClassifyModel:
        self._labels = {}
        self._hf_token = hf_token

    def _get_hf_token(self):
    def _get_hf_token(self) -> Optional[str]:
        """
        Get the Hugging Face token from the instance variable or environment variable.

        :return: The Hugging Face token.
        :rtype: str
        :rtype: Optional[str]
        """
        return self._hf_token or os.environ.get('HF_TOKEN')

@@ -288,7 +289,7 @@ class ClassifyModel:


@lru_cache()
def _open_models_for_repo_id(repo_id: str) -> ClassifyModel:
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> ClassifyModel:
    """
    Open and cache a ClassifyModel instance for the specified repository ID.

@@ -297,14 +298,17 @@ def _open_models_for_repo_id(repo_id: str) -> ClassifyModel:

    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :return: A ClassifyModel instance for the specified repository.
    :rtype: ClassifyModel
    """
    return ClassifyModel(repo_id)
    return ClassifyModel(repo_id, hf_token=hf_token)


def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str) -> Dict[str, float]:
def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str,
                           hf_token: Optional[str] = None) -> Dict[str, float]:
    """
    Predict the scores for each class using the specified model and repository.

@@ -316,6 +320,8 @@ def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str) ->
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :return: A dictionary mapping class labels to their predicted scores.
    :rtype: Dict[str, float]
@@ -323,10 +329,11 @@ def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str) ->
    :raises ValueError: If the model name or repository ID is invalid.
    :raises RuntimeError: If there's an error during prediction.
    """
    return _open_models_for_repo_id(repo_id).predict_score(image, model_name)
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict_score(image, model_name)


def classify_predict(image: ImageTyping, repo_id: str, model_name: str) -> Tuple[str, float]:
def classify_predict(image: ImageTyping, repo_id: str, model_name: str,
                     hf_token: Optional[str] = None) -> Tuple[str, float]:
    """
    Predict the class with the highest score using the specified model and repository.

@@ -338,6 +345,8 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str) -> Tuple
    :type repo_id: str
    :param model_name: The name of the model to use for prediction.
    :type model_name: str
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :return: A tuple containing the predicted class label and its score.
    :rtype: Tuple[str, float]
@@ -345,4 +354,4 @@ def classify_predict(image: ImageTyping, repo_id: str, model_name: str) -> Tuple
    :raises ValueError: If the model name or repository ID is invalid.
    :raises RuntimeError: If there's an error during prediction.
    """
    return _open_models_for_repo_id(repo_id).predict(image, model_name)
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict(image, model_name)
+8 −5
Original line number Diff line number Diff line
@@ -298,7 +298,7 @@ class YOLOModel:
        self._models = {}
        self._hf_token = hf_token

    def _get_hf_token(self):
    def _get_hf_token(self) -> Optional[str]:
        """
        Get the Hugging Face token, either from the instance or environment variable.

@@ -408,7 +408,7 @@ class YOLOModel:


@lru_cache()
def _open_models_for_repo_id(repo_id: str) -> YOLOModel:
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> YOLOModel:
    """
    Load and cache a YOLO model from a Hugging Face repository.

@@ -417,6 +417,8 @@ def _open_models_for_repo_id(repo_id: str) -> YOLOModel:

    :param repo_id: The Hugging Face repository ID for the YOLO model.
    :type repo_id: str
    :param hf_token: Optional Hugging Face authentication token.
    :type hf_token: Optional[str]

    :return: The loaded YOLO model.
    :rtype: YOLOModel
@@ -428,11 +430,12 @@ def _open_models_for_repo_id(repo_id: str) -> YOLOModel:
        >>> # Subsequent calls with the same repo_id will return the cached model
        >>> same_model = _open_models_for_repo_id("yolov5/yolov5s")
    """
    return YOLOModel(repo_id)
    return YOLOModel(repo_id, hf_token=hf_token)


def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
                 conf_threshold: float = 0.25, iou_threshold: float = 0.7) \
                 conf_threshold: float = 0.25, iou_threshold: float = 0.7,
                 hf_token: Optional[str] = None) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Perform object detection on an image using a YOLO model from a Hugging Face repository.
@@ -464,7 +467,7 @@ def yolo_predict(image: ImageTyping, repo_id: str, model_name: str,
    >>> print(detections[0])  # First detection
    ((100, 200, 300, 400), 'person', 0.95)
    """
    return _open_models_for_repo_id(repo_id).predict(
    return _open_models_for_repo_id(repo_id, hf_token=hf_token).predict(
        image=image,
        model_name=model_name,
        conf_threshold=conf_threshold,