Commit e8bda959 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add webui

parent d73611dc
Loading
Loading
Loading
Loading
+92 −2
Original line number Diff line number Diff line
import json
import os
from threading import Lock
from typing import List, Union, Optional, Any
from typing import List, Union, Optional, Any, Dict

import numpy as np
from hfutils.operate import get_hf_client
from hfutils.repository import hf_hub_repo_url
from hfutils.utils import hf_normpath, hf_fs_path, parse_hf_fs_path
from huggingface_hub import hf_hub_download, HfFileSystem
from tokenizers import Tokenizer

from ..data import MultiImagesTyping, load_images
from ..data import MultiImagesTyping, load_images, ImageTyping
from ..preprocess import create_pillow_transforms
from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache

try:
    import gradio as gr
except (ImportError, ModuleNotFoundError):
    gr = None

__all__ = [
    'SigLIPModel',
    'siglip_image_encode',
@@ -20,6 +27,17 @@ __all__ = [
]


def _check_gradio_env():
    """
    Check if the Gradio library is installed and available.

    :raises EnvironmentError: If Gradio is not installed.
    """
    if gr is None:
        raise EnvironmentError(f'Gradio required for launching webui-based demo.\n'
                               f'Please install it with `pip install dghs-imgutils[demo]`.')


class SigLIPModel:
    def __init__(self, repo_id: str, hf_token: Optional[str] = None):
        self.repo_id = repo_id
@@ -320,6 +338,78 @@ class SigLIPModel:
        self._text_tokenizers.clear()
        self._logit_scales.clear()

    def make_ui(self, default_model_name: Optional[str] = None):
        _check_gradio_env()
        model_list = self.model_names
        if not default_model_name:
            hf_client = get_hf_client(hf_token=self._get_hf_token())
            selected_model_name, selected_time = None, None
            for fileitem in hf_client.get_paths_info(
                    repo_id=self.repo_id,
                    repo_type='model',
                    paths=[f'{model_name}/image_encode.onnx' for model_name in model_list],
                    expand=True,
            ):
                if not selected_time or fileitem.last_commit.date > selected_time:
                    selected_model_name = os.path.dirname(fileitem.path)
                    selected_time = fileitem.last_commit.date
            default_model_name = selected_model_name

        def _gr_detect(image: ImageTyping, raw_text: str, model_name: str) -> Dict[str, float]:
            labels = []
            for line in raw_text.splitlines(keepends=False):
                line = line.strip()
                if line:
                    labels.append(line)

            prediction = self.predict(images=[image], texts=labels, model_name=model_name)[0]
            return dict(zip(labels, prediction.tolist()))

        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_raw_text = gr.TextArea(value='', lines=5, autoscroll=True, label='Labels')
                with gr.Row():
                    gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_labels = gr.Label(label='Prediction')

            gr_submit.click(
                _gr_detect,
                inputs=[
                    gr_input_image,
                    gr_raw_text,
                    gr_model,
                ],
                outputs=[gr_output_labels],
            )

    def launch_demo(self, default_model_name: Optional[str] = None,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        _check_gradio_env()
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
                    gr.HTML(f'<h2 style="text-align: center;">SigLIP Demo For {self.repo_id}</h2>')
                    gr.Markdown(f'This is the quick demo for SigLIP model [{self.repo_id}]({repo_url}). '
                                f'Powered by `dghs-imgutils`\'s quick demo module.')

            with gr.Row():
                self.make_ui(
                    default_model_name=default_model_name,
                )

        demo.launch(
            server_name=server_name,
            server_port=server_port,
            **kwargs,
        )


@ts_lru_cache()
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) -> SigLIPModel: