Commit 68d41f94 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add multi label groups for classifiers

parent 9f7f532a
Loading
Loading
Loading
Loading
+36 −10
Original line number Diff line number Diff line
@@ -217,7 +217,7 @@ class ClassifyModel:

        return self._models[model_name]

    def _open_label(self, model_name: str) -> List[str]:
    def _open_label(self, model_name: str) -> Dict[str, List[str]]:
        """
        Load and cache model labels from metadata.

@@ -228,7 +228,7 @@ class ClassifyModel:
        :type model_name: str

        :return: Mapping from label group name to its list of labels; the ``default`` group holds the model's primary labels
        :rtype: List[str]
        :rtype: Dict[str, List[str]]

        :raises RuntimeError: If label loading fails
        """
@@ -240,7 +240,11 @@ class ClassifyModel:
                        f'{model_name}/meta.json',
                        token=self._get_hf_token(),
                ), 'r') as f:
                    self._labels[model_name] = json.load(f)['labels']
                    meta_info = json.load(f)
                    self._labels[model_name] = {
                        **(meta_info.get('other_labels') or {}),
                        'default': meta_info['labels']
                    }

        return self._labels[model_name]

@@ -303,7 +307,7 @@ class ClassifyModel:
        output, = self._open_model(model_name).run(['output'], {'input': input_})
        return output

    def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]:
    def predict_score(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Dict[str, float]:
        """
        Predict the scores for each class using the specified model.

@@ -321,10 +325,10 @@ class ClassifyModel:
        :raises RuntimeError: If there's an error during prediction.
        """
        output = self._raw_predict(image, model_name)
        values = dict(zip(self._open_label(model_name), map(lambda x: x.item(), output[0])))
        values = dict(zip(self._open_label(model_name)[label_group], map(lambda x: x.item(), output[0])))
        return values

    def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]:
    def predict(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Tuple[str, float]:
        """
        Predict the class with the highest score for the given image.

@@ -343,7 +347,7 @@ class ClassifyModel:
        """
        output = self._raw_predict(image, model_name)[0]
        max_id = np.argmax(output)
        return self._open_label(model_name)[max_id], output[max_id].item()
        return self._open_label(model_name)[label_group][max_id], output[max_id].item()

    def clear(self):
        """
@@ -391,21 +395,43 @@ class ClassifyModel:

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                    gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()),
                                                 value='default', label='Label Group')
                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output = gr.Label(label='Prediction')

            def _fn_label_group(new_model_name, old_label_group):
                """
                Rebuild the label group dropdown after the selected model changes.

                :param new_model_name: Name of the newly selected model.
                :param old_label_group: Label group that was selected before the change.

                :return: A dropdown listing the label groups of the new model. The
                    previous selection is kept when the new model also provides it,
                    otherwise ``'default'`` is selected.
                """
                group_labels = self._open_label(new_model_name)

                # Keep the user's selection only if the new model has that group.
                if old_label_group in group_labels:
                    selected_group = old_label_group
                else:
                    selected_group = 'default'

                return gr.Dropdown(list(group_labels), value=selected_group, label='Label Group')

            gr_submit.click(
                self.predict_score,
                inputs=[
                    gr_input_image,
                    gr_model,
                    gr_label_group,
                ],
                outputs=[gr_output],
            )
            gr_model.change(
                _fn_label_group,
                inputs=[
                    gr_model,
                    gr_label_group
                ],
                outputs=[gr_label_group]
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):