Commit 49efe9b6 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): update the classifier system

parent 68d41f94
Loading
Loading
Loading
Loading
+24 −5
Original line number Diff line number Diff line
@@ -217,7 +217,7 @@ class ClassifyModel:

        return self._models[model_name]

    def _open_label(self, model_name: str) -> Dict[str, List[str]]:
    def _open_label(self, model_name: str) -> Dict[str, np.ndarray]:
        """
        Load and cache model labels from metadata.

@@ -228,7 +228,7 @@ class ClassifyModel:
        :type model_name: str

        :return: List of model labels
        :rtype: Dict[str, List[str]]
        :rtype: Dict[str, np.ndarray]

        :raises RuntimeError: If label loading fails
        """
@@ -241,10 +241,14 @@ class ClassifyModel:
                        token=self._get_hf_token(),
                ), 'r') as f:
                    meta_info = json.load(f)
                    self._labels[model_name] = {
                    d_groups = {
                        **(meta_info.get('other_labels') or {}),
                        'default': meta_info['labels']
                    }
                    self._labels[model_name] = {
                        key: np.array(labels)
                        for key, labels in d_groups.items()
                    }

        return self._labels[model_name]

@@ -307,7 +311,8 @@ class ClassifyModel:
        output, = self._open_model(model_name).run(['output'], {'input': input_})
        return output

    def predict_score(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Dict[str, float]:
    def predict_score(self, image: ImageTyping, model_name: str,
                      label_group: str = 'default', topk: Optional[int] = 20) -> Dict[str, float]:
        """
        Predict the scores for each class using the specified model.

@@ -317,6 +322,10 @@ class ClassifyModel:
        :type image: ImageTyping
        :param model_name: The name of the model to use for prediction.
        :type model_name: str
        :param label_group: Label group for the classification result.
        :type label_group: str
        :param topk: Top-K result. Default is 20, return all results when None ia assigned.
        :type topk: Optional[int]

        :return: A dictionary mapping class labels to their predicted scores.
        :rtype: Dict[str, float]
@@ -325,7 +334,15 @@ class ClassifyModel:
        :raises RuntimeError: If there's an error during prediction.
        """
        output = self._raw_predict(image, model_name)
        values = dict(zip(self._open_label(model_name)[label_group], map(lambda x: x.item(), output[0])))
        labels = self._open_label(model_name)[label_group]
        scores = output[0]
        if topk and topk < labels.shape[-1]:
            indices = np.argpartition(scores, -topk)[-topk:]
            indices = indices[np.argsort(-scores[indices], kind='mergesort')]
            labels, scores = labels[indices], scores[indices]

        # noinspection PyTypeChecker
        values = dict(zip(labels.tolist(), scores.tolist()))
        return values

    def predict(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Tuple[str, float]:
@@ -338,6 +355,8 @@ class ClassifyModel:
        :type image: ImageTyping
        :param model_name: The name of the model to use for prediction.
        :type model_name: str
        :param label_group: Label group for the classification result.
        :type label_group: str

        :return: A tuple containing the predicted class label and its score.
        :rtype: Tuple[str, float]