Commit c693aeef authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add demo webui for CLIP

parent 0815ab7a
Loading
Loading
Loading
Loading
+77 −2
Original line number Diff line number Diff line
import json
import os
from threading import Lock
from typing import List, Union, Optional, Any
from typing import List, Union, Optional, Any, Dict

import numpy as np
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_normpath, parse_hf_fs_path, hf_fs_path
from huggingface_hub import hf_hub_download, HfFileSystem
from tokenizers import Tokenizer

from imgutils.data import MultiImagesTyping, load_images
from imgutils.data import MultiImagesTyping, load_images, ImageTyping
from imgutils.preprocess import create_pillow_transforms
from imgutils.utils import open_onnx_model, vreplace, ts_lru_cache

@@ -264,6 +266,79 @@ class CLIPModel:
        self._text_tokenizers.clear()
        self._logit_scales.clear()

    def make_ui(self, default_model_name: Optional[str] = None):
        _check_gradio_env()
        model_list = self.model_names
        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
            for fileitem in hf_client.get_paths_info(
                    repo_id=self.repo_id,
                    repo_type='model',
                    paths=[f'{model_name}/image_encode.onnx' for model_name in model_list],
                    expand=True,
            ):
                if not selected_time or fileitem.last_commit.date > selected_time:
                    selected_model_name = os.path.dirname(fileitem.path)
                    selected_time = fileitem.last_commit.date
            default_model_name = selected_model_name

        def _gr_detect(image: ImageTyping, raw_text: str, model_name: str) -> Dict[str, float]:
            labels = []
            for line in raw_text.splitlines(keepends=False):
                line = line.strip()
                if line:
                    labels.append(line)

            prediction = self.predict(images=[image], texts=labels, model_name=model_name)[0]
            return dict(zip(labels, prediction.tolist()))

        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels',
                                              placeholder='Enter labels, one per line')
                with gr.Row():
                    gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_labels = gr.Label(label='Prediction')

            gr_submit.click(
                _gr_detect,
                inputs=[
                    gr_input_image,
                    gr_raw_text,
                    gr_model,
                ],
                outputs=[gr_output_labels],
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        _check_gradio_env()
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
                    gr.HTML(f'<h2 style="text-align: center;">CLIP Demo For {self.repo_id}</h2>')
                    gr.Markdown(f'This is the quick demo for CLIP model [{self.repo_id}]({repo_url}). '
                                f'Powered by `dghs-imgutils`\'s quick demo module.')

            with gr.Row():
                self.make_ui(
                    default_model_name=default_model_name,
                )

        demo.launch(
            server_name=server_name,
            server_port=server_port,
            **kwargs,
        )


@ts_lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> CLIPModel:
+2 −1
Original line number Diff line number Diff line
@@ -438,7 +438,8 @@ class SigLIPModel:
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels')
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels',
                                              placeholder='Enter labels, one per line')
                with gr.Row():
                    gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')