Commit deb215eb authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): save the predict fmt code

parent 9455d1ff
Loading
Loading
Loading
Loading
+28 −4
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ It also handles token-based authentication for accessing private Hugging Face re

import json
import os
import re
from threading import Lock
from typing import Tuple, Optional, List, Dict, Callable

@@ -25,7 +26,7 @@ from huggingface_hub.errors import EntryNotFoundError

from ..data import rgb_encode, ImageTyping, load_image
from ..preprocess import create_pillow_transforms
from ..utils import open_onnx_model, ts_lru_cache
from ..utils import open_onnx_model, ts_lru_cache, vnames, vreplace

try:
    import gradio as gr
@@ -380,12 +381,35 @@ class ClassifyModel:
        max_id = np.argmax(output)
        return self._open_label(model_name)[label_group][max_id], output[max_id].item()

    def predict_fmt(self, image: ImageTyping, model_name: str,
                      label_group: str = 'default', topk: Optional[int] = 20):
    def predict_fmt(self, image: ImageTyping, model_name: str, fmt='scores-top5'):
        d_data = {name: value[0] for name, value in self._raw_predict(image, model_name).items()}
        scores = d_data['output']
        d_labels = self._open_label(model_name)
        vname_to_spair = {}
        d_scores = {}
        for vname in vnames(fmt, str_only=True):
            matching = re.fullmatch(r'^scores(-top(?P<topk>\d+))?(-(?P<label_group>[a-zA-Z\d_]+))?$', vname)
            if matching:
                if matching.group('topk'):
                    topk = int(matching.group('topk'))
                else:
                    topk = None
                if matching.group('label_group'):
                    group_label = matching.group('label_group')
                else:
                    group_label = 'default'
                vname_to_spair[vname] = (topk, group_label)
                if (topk, group_label) not in d_scores:
                    d_scores[(topk, group_label)] = _labels_scores_to_topk(
                        labels=d_labels[group_label],
                        scores=scores,
                        topk=topk,
                    )


        return vreplace(fmt, mapping={
            **d_data,
            **{vname: d_scores[vpair] for vname, vpair in vname_to_spair.items()},
        })

    def clear(self):
        """