Loading imgutils/generic/classify.py +28 −4 Original line number Diff line number Diff line Loading @@ -12,6 +12,7 @@ It also handles token-based authentication for accessing private Hugging Face re import json import os import re from threading import Lock from typing import Tuple, Optional, List, Dict, Callable Loading @@ -25,7 +26,7 @@ from huggingface_hub.errors import EntryNotFoundError from ..data import rgb_encode, ImageTyping, load_image from ..preprocess import create_pillow_transforms from ..utils import open_onnx_model, ts_lru_cache from ..utils import open_onnx_model, ts_lru_cache, vnames, vreplace try: import gradio as gr Loading Loading @@ -380,12 +381,35 @@ class ClassifyModel: max_id = np.argmax(output) return self._open_label(model_name)[label_group][max_id], output[max_id].item() def predict_fmt(self, image: ImageTyping, model_name: str, label_group: str = 'default', topk: Optional[int] = 20): def predict_fmt(self, image: ImageTyping, model_name: str, fmt='scores-top5'): d_data = {name: value[0] for name, value in self._raw_predict(image, model_name).items()} scores = d_data['output'] d_labels = self._open_label(model_name) vname_to_spair = {} d_scores = {} for vname in vnames(fmt, str_only=True): matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname) if matching: if matching.group('topk'): topk = int(matching.group('topk')) else: topk = None if matching.group('label_group'): group_label = matching.group('label_group') else: group_label = 'default' vname_to_spair[vname] = (topk, group_label) if (topk, group_label) not in d_scores: d_scores[(topk, group_label)] = _labels_scores_to_topk( labels=d_labels[group_label], scores=scores, topk=topk, ) return vreplace(fmt, mapping={ **d_data, **{vname: d_scores[vpair] for vname, vpair in vname_to_spair.items()}, }) def clear(self): """ Loading Loading
imgutils/generic/classify.py +28 −4 Original line number Diff line number Diff line Loading @@ -12,6 +12,7 @@ It also handles token-based authentication for accessing private Hugging Face re import json import os import re from threading import Lock from typing import Tuple, Optional, List, Dict, Callable Loading @@ -25,7 +26,7 @@ from huggingface_hub.errors import EntryNotFoundError from ..data import rgb_encode, ImageTyping, load_image from ..preprocess import create_pillow_transforms from ..utils import open_onnx_model, ts_lru_cache from ..utils import open_onnx_model, ts_lru_cache, vnames, vreplace try: import gradio as gr Loading Loading @@ -380,12 +381,35 @@ class ClassifyModel: max_id = np.argmax(output) return self._open_label(model_name)[label_group][max_id], output[max_id].item() def predict_fmt(self, image: ImageTyping, model_name: str, label_group: str = 'default', topk: Optional[int] = 20): def predict_fmt(self, image: ImageTyping, model_name: str, fmt='scores-top5'): d_data = {name: value[0] for name, value in self._raw_predict(image, model_name).items()} scores = d_data['output'] d_labels = self._open_label(model_name) vname_to_spair = {} d_scores = {} for vname in vnames(fmt, str_only=True): matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname) if matching: if matching.group('topk'): topk = int(matching.group('topk')) else: topk = None if matching.group('label_group'): group_label = matching.group('label_group') else: group_label = 'default' vname_to_spair[vname] = (topk, group_label) if (topk, group_label) not in d_scores: d_scores[(topk, group_label)] = _labels_scores_to_topk( labels=d_labels[group_label], scores=scores, topk=topk, ) return vreplace(fmt, mapping={ **d_data, **{vname: d_scores[vpair] for vname, vpair in vname_to_spair.items()}, }) def clear(self): """ Loading