Loading imgutils/generic/classify.py +36 −10 Original line number Diff line number Diff line Loading @@ -217,7 +217,7 @@ class ClassifyModel: return self._models[model_name] def _open_label(self, model_name: str) -> List[str]: def _open_label(self, model_name: str) -> Dict[str, List[str]]: """ Load and cache model labels from metadata. Loading @@ -228,7 +228,7 @@ class ClassifyModel: :type model_name: str :return: List of model labels :rtype: List[str] :rtype: Dict[str, List[str]] :raises RuntimeError: If label loading fails """ Loading @@ -240,7 +240,11 @@ class ClassifyModel: f'{model_name}/meta.json', token=self._get_hf_token(), ), 'r') as f: self._labels[model_name] = json.load(f)['labels'] meta_info = json.load(f) self._labels[model_name] = { **(meta_info.get('other_labels') or {}), 'default': meta_info['labels'] } return self._labels[model_name] Loading Loading @@ -303,7 +307,7 @@ class ClassifyModel: output, = self._open_model(model_name).run(['output'], {'input': input_}) return output def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]: def predict_score(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Dict[str, float]: """ Predict the scores for each class using the specified model. Loading @@ -321,10 +325,10 @@ class ClassifyModel: :raises RuntimeError: If there's an error during prediction. """ output = self._raw_predict(image, model_name) values = dict(zip(self._open_label(model_name), map(lambda x: x.item(), output[0]))) values = dict(zip(self._open_label(model_name)[label_group], map(lambda x: x.item(), output[0]))) return values def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]: def predict(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Tuple[str, float]: """ Predict the class with the highest score for the given image. Loading @@ -343,7 +347,7 @@ class ClassifyModel: """ output = self._raw_predict(image, model_name)[0] max_id = np.argmax(output) return self._open_label(model_name)[max_id], output[max_id].item() return self._open_label(model_name)[label_group][max_id], output[max_id].item() def clear(self): """ Loading Loading @@ -391,21 +395,43 @@ class ClassifyModel: with gr.Row(): with gr.Column(): with gr.Row(): gr_input_image = gr.Image(type='pil', label='Original Image') with gr.Row(): gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()), value='default', label='Label Group') with gr.Row(): gr_submit = gr.Button(value='Submit', variant='primary') with gr.Column(): gr_output = gr.Label(label='Prediction') def _fn_label_group(new_model_name, old_label_group): labels_info = self._open_label(new_model_name) return gr.Dropdown( list(labels_info.keys()), value=old_label_group if old_label_group in labels_info else 'default', label='Label Group' ) gr_submit.click( self.predict_score, inputs=[ gr_input_image, gr_model, gr_label_group, ], outputs=[gr_output], ) gr_model.change( _fn_label_group, inputs=[ gr_model, gr_label_group ], outputs=[gr_label_group] ) def launch_demo(self, default_model_name: Optional[str] = None, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): Loading Loading
imgutils/generic/classify.py +36 −10 Original line number Diff line number Diff line Loading @@ -217,7 +217,7 @@ class ClassifyModel: return self._models[model_name] def _open_label(self, model_name: str) -> List[str]: def _open_label(self, model_name: str) -> Dict[str, List[str]]: """ Load and cache model labels from metadata. Loading @@ -228,7 +228,7 @@ class ClassifyModel: :type model_name: str :return: List of model labels :rtype: List[str] :rtype: Dict[str, List[str]] :raises RuntimeError: If label loading fails """ Loading @@ -240,7 +240,11 @@ class ClassifyModel: f'{model_name}/meta.json', token=self._get_hf_token(), ), 'r') as f: self._labels[model_name] = json.load(f)['labels'] meta_info = json.load(f) self._labels[model_name] = { **(meta_info.get('other_labels') or {}), 'default': meta_info['labels'] } return self._labels[model_name] Loading Loading @@ -303,7 +307,7 @@ class ClassifyModel: output, = self._open_model(model_name).run(['output'], {'input': input_}) return output def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]: def predict_score(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Dict[str, float]: """ Predict the scores for each class using the specified model. Loading @@ -321,10 +325,10 @@ class ClassifyModel: :raises RuntimeError: If there's an error during prediction. """ output = self._raw_predict(image, model_name) values = dict(zip(self._open_label(model_name), map(lambda x: x.item(), output[0]))) values = dict(zip(self._open_label(model_name)[label_group], map(lambda x: x.item(), output[0]))) return values def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]: def predict(self, image: ImageTyping, model_name: str, label_group: str = 'default') -> Tuple[str, float]: """ Predict the class with the highest score for the given image. Loading @@ -343,7 +347,7 @@ class ClassifyModel: """ output = self._raw_predict(image, model_name)[0] max_id = np.argmax(output) return self._open_label(model_name)[max_id], output[max_id].item() return self._open_label(model_name)[label_group][max_id], output[max_id].item() def clear(self): """ Loading Loading @@ -391,21 +395,43 @@ class ClassifyModel: with gr.Row(): with gr.Column(): with gr.Row(): gr_input_image = gr.Image(type='pil', label='Original Image') with gr.Row(): gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model') gr_label_group = gr.Dropdown(list(self._open_label(default_model_name).keys()), value='default', label='Label Group') with gr.Row(): gr_submit = gr.Button(value='Submit', variant='primary') with gr.Column(): gr_output = gr.Label(label='Prediction') def _fn_label_group(new_model_name, old_label_group): labels_info = self._open_label(new_model_name) return gr.Dropdown( list(labels_info.keys()), value=old_label_group if old_label_group in labels_info else 'default', label='Label Group' ) gr_submit.click( self.predict_score, inputs=[ gr_input_image, gr_model, gr_label_group, ], outputs=[gr_output], ) gr_model.change( _fn_label_group, inputs=[ gr_model, gr_label_group ], outputs=[gr_label_group] ) def launch_demo(self, default_model_name: Optional[str] = None, server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs): Loading