Commit 90b06731 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add multilabel infer code

parent 3d1ae0bd
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -5,5 +5,6 @@ Overview:
from .classify import *
from .clip import *
from .enhance import *
from .multilabel import *
from .siglip import *
from .yolo import *
+324 −0
Original line number Diff line number Diff line
import json
import os
import warnings
from threading import Lock
from typing import Optional, Literal, Dict, Any, Union, Tuple

import pandas as pd
from hbutils.string import titleize
from hfutils.repository import hf_hub_repo_url
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError

from imgutils.data import ImageTyping, load_image
from imgutils.utils import vreplace, ts_lru_cache
from ..preprocess import create_pillow_transforms
from ..utils import open_onnx_model

# Gradio is an optional dependency: when it cannot be imported, ``gr`` is left as ``None``
# and ``_check_gradio_env`` raises before any demo code touches it.
try:
    import gradio as gr
except (ImportError, ModuleNotFoundError):
    gr = None

# Names exported by a star-import of this module.
__all__ = [
    'MultiLabelTIMMModel',
    'multilabel_predict',
]


def _check_gradio_env():
    """
    Ensure the optional Gradio dependency is usable before building a demo.

    The module-level ``gr`` name is ``None`` when importing Gradio failed; in that
    case an error carrying the installation hint is raised. Otherwise this is a no-op.

    :raises EnvironmentError: If the Gradio package could not be imported.
    """
    if gr is not None:
        return

    message = ('Gradio required for launching webui-based demo.\n'
               'Please install it with `pip install dghs-imgutils[demo]`.')
    raise EnvironmentError(message)


# Sentinel default for ``fmt`` arguments: lets ``predict`` tell "no format supplied"
# (it then builds one from the category names) apart from any real value.
FMT_UNSET = object()


class MultiLabelTIMMModel:
    """
    Lazy-loading wrapper around a multi-label ONNX model stored in a Hugging Face model repository.

    The repository is read for ``model.onnx``, ``selected_tags.csv`` and ``preprocess.json``;
    ``thresholds.csv`` is optional. Each resource is downloaded on first use, guarded by a
    per-instance lock, and then kept on the instance.

    :param repo_id: Hugging Face model repository id.
    :type repo_id: str
    :param hf_token: Access token used for every download; when omitted, the ``HF_TOKEN``
        environment variable is consulted instead.
    :type hf_token: Optional[str]
    :param category_names: Mapping from category id (as found in the ``category`` column of
        ``selected_tags.csv``) to a display name. Categories without a name are called
        ``category_<id>``. The mapping is copied, so the caller's dict is never modified.
    :type category_names: Optional[Dict[Any, str]]
    """

    # Threshold applied to a category when neither the caller nor the repository provides one.
    _FALLBACK_THRESHOLD = 0.4

    def __init__(self, repo_id: str, hf_token: Optional[str] = None,
                 category_names: Optional[Dict[Any, str]] = None):
        self.repo_id = repo_id
        self._model = None
        self._df_tags = None
        self._preprocess = None
        self._default_category_thresholds = None
        self._hf_token = hf_token
        self._lock = Lock()
        # Copy the mapping: _open_tags() fills in missing names and must not write
        # them into the dict the caller handed in.
        self._category_names = dict(category_names or {})
        self._name_to_categories = None

    def _get_hf_token(self) -> Optional[str]:
        """
        Retrieve the Hugging Face authentication token.

        The token given to the constructor takes precedence; otherwise the ``HF_TOKEN``
        environment variable is used.

        :return: Authentication token if available
        :rtype: Optional[str]
        """
        return self._hf_token or os.environ.get('HF_TOKEN')

    @staticmethod
    def _categories(df_tags) -> list:
        """
        List the distinct values of the ``category`` column in ascending order.

        :param df_tags: Tag table as returned by :meth:`_open_tags`.
        :return: Sorted list of distinct category ids.
        :rtype: list
        """
        return sorted(set(df_tags['category'].tolist()))

    @staticmethod
    def _is_scalar_threshold(value) -> bool:
        """
        Tell whether ``value`` is a single numeric threshold meant for every category.

        Integers are accepted as well as floats so that ``thresholds=0`` or ``thresholds=1``
        are honoured; booleans are excluded although ``bool`` subclasses ``int``.

        :param value: Candidate threshold argument.
        :return: ``True`` for ``int``/``float`` values that are not ``bool``.
        :rtype: bool
        """
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _explicit_threshold(self, category, thresholds) -> Tuple[bool, Any]:
        """
        Resolve the threshold the caller explicitly requested for one category.

        :param category: Category id from the tag table.
        :param thresholds: Either a single number for all categories, or a dict keyed by
            category id or by category name, or anything else (treated as "not specified").
        :return: ``(True, threshold)`` when a threshold was specified for this category,
            ``(False, None)`` otherwise.
        :rtype: Tuple[bool, Any]
        """
        if self._is_scalar_threshold(thresholds):
            return True, thresholds
        if isinstance(thresholds, dict):
            # A key given as the raw category id wins over one given as the category name.
            for key in (category, self._category_names[category]):
                if key in thresholds:
                    return True, thresholds[key]
        return False, None

    def _open_model(self):
        """
        Download ``model.onnx`` on first use and open it with ``open_onnx_model``.

        :return: The opened model object, shared by later calls.
        """
        with self._lock:
            if self._model is None:
                self._model = open_onnx_model(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
                    filename='model.onnx',
                    token=self._get_hf_token(),
                ))

        return self._model

    def _open_tags(self):
        """
        Download and parse ``selected_tags.csv`` on first use.

        As a side effect, every category found in the table gets a display name
        (``category_<id>`` when none was configured) and the reverse name-to-category
        mapping is built.

        :return: The tag table.
        :rtype: pandas.DataFrame
        """
        with self._lock:
            if self._df_tags is None:
                df_tags = pd.read_csv(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type='model',
                    filename='selected_tags.csv',
                    token=self._get_hf_token(),
                ))
                name_to_categories = {}
                for category in self._categories(df_tags):
                    if not self._category_names.get(category):
                        self._category_names[category] = f'category_{category}'
                    name_to_categories[self._category_names[category]] = category

                self._name_to_categories = name_to_categories
                # Published last, so a failure above leaves the instance unloaded and
                # the next call retries instead of seeing half-initialised state.
                self._df_tags = df_tags

        return self._df_tags

    def _open_preprocess(self):
        """
        Download ``preprocess.json`` on first use and build the image transforms from it.

        :return: Tuple ``(val_transform, test_transform)`` built from the ``val`` and
            ``test`` entries of the JSON file.
        :rtype: tuple
        """
        with self._lock:
            if self._preprocess is None:
                with open(hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename='preprocess.json',
                        # Same token as the other downloads, so private/gated repos work too.
                        token=self._get_hf_token(),
                ), 'r') as f:
                    data_ = json.load(f)

                test_trans = create_pillow_transforms(data_['test'])
                val_trans = create_pillow_transforms(data_['val'])
                self._preprocess = val_trans, test_trans

        return self._preprocess

    def _open_default_category_thresholds(self):
        """
        Load the per-category default thresholds from ``thresholds.csv`` on first use.

        A repository without that file yields an empty mapping. When a category appears
        on several rows, the first row wins.

        :return: Mapping from category id to its default threshold.
        :rtype: dict
        """
        with self._lock:
            if self._default_category_thresholds is None:
                category_thresholds = {}
                try:
                    df_category_thresholds = pd.read_csv(hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename='thresholds.csv',
                        # Same token as the other downloads, so private/gated repos work too.
                        token=self._get_hf_token(),
                    ))
                except (EntryNotFoundError,):
                    # The file is optional; fall back to "no defaults".
                    pass
                else:
                    for item in df_category_thresholds.to_dict('records'):
                        category_thresholds.setdefault(item['category'], item['threshold'])

                # Assigned only once fully built, so a malformed file cannot leave a
                # partially filled mapping cached on the instance.
                self._default_category_thresholds = category_thresholds

        return self._default_category_thresholds

    def _raw_predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test'):
        """
        Run the model on one image and return its raw outputs.

        The image is loaded as RGB on a white background, passed through the selected
        transform, given a leading axis and fed to the model under the input name ``input``.

        :param image: Image to run the model on.
        :type image: ImageTyping
        :param preprocessor: Which transform from ``preprocess.json`` to apply.
        :type preprocessor: Literal['test', 'val']
        :return: Mapping from each model output name to the first item of that output.
        :rtype: dict
        :raises ValueError: If ``preprocessor`` is neither ``'test'`` nor ``'val'``.
        """
        image = load_image(image, force_background='white', mode='RGB')
        model = self._open_model()

        val_trans, test_trans = self._open_preprocess()
        if preprocessor == 'test':
            trans = test_trans
        elif preprocessor == 'val':
            trans = val_trans
        else:
            raise ValueError(f'Unknown processor - {preprocessor!r}.')

        input_ = trans(image)[None, ...]
        output_names = [output.name for output in model.get_outputs()]
        output_values = model.run(output_names, {'input': input_})
        return {name: value[0] for name, value in zip(output_names, output_values)}

    def predict(self, image: ImageTyping, preprocessor: Literal['test', 'val'] = 'test',
                thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = False,
                fmt=FMT_UNSET):
        """
        Predict the tags of one image, grouped by category.

        For every category the threshold is chosen in this order:

        1. ``thresholds`` when it is a single number;
        2. the entry of ``thresholds`` keyed by the category id, then by the category name,
           when ``thresholds`` is a dict;
        3. the per-tag ``best_threshold`` column of the tag table, when
           ``use_tag_thresholds`` is set and that column exists;
        4. the category's entry in ``thresholds.csv``;
        5. ``0.4``.

        :param image: Image to tag.
        :type image: ImageTyping
        :param preprocessor: Which transform from ``preprocess.json`` to apply.
        :type preprocessor: Literal['test', 'val']
        :param thresholds: A single threshold for all categories, or a dict keyed by
            category id or category name.
        :type thresholds: Union[float, Dict[Any, float]]
        :param use_tag_thresholds: Use per-tag thresholds where the caller gave no explicit
            threshold. A warning is issued when the model provides none.
        :type use_tag_thresholds: bool
        :param fmt: Output format handed to ``vreplace``; defaults to a tuple of all
            category names in ascending category order.
        :return: ``vreplace(fmt, values)``, where ``values`` maps each category name to a
            dict of tag name to score (ordered by descending score, then tag name) and
            additionally holds every model output other than ``prediction`` under its name.
        :raises ValueError: If ``preprocessor`` is neither ``'test'`` nor ``'val'``.
        """
        df_tags = self._open_tags()
        values = self._raw_predict(image, preprocessor=preprocessor)
        prediction = values.pop('prediction')
        categories = self._categories(df_tags)

        if fmt is FMT_UNSET:
            fmt = tuple(self._category_names[category] for category in categories)

        default_category_thresholds = self._open_default_category_thresholds()
        default_tag_thresholds = df_tags['best_threshold'] if 'best_threshold' in df_tags else None

        for category in categories:
            mask = df_tags['category'] == category
            tag_names = df_tags['name'][mask]
            category_pred = prediction[mask]

            specified, category_threshold = self._explicit_threshold(category, thresholds)
            if not specified:
                if use_tag_thresholds and default_tag_thresholds is not None:
                    # One threshold per tag of this category, aligned with category_pred.
                    category_threshold = default_tag_thresholds[mask]
                else:
                    if use_tag_thresholds:
                        warnings.warn(f'Tag thresholds not supported in model {self.repo_id!r}.')
                    category_threshold = default_category_thresholds.get(category, self._FALLBACK_THRESHOLD)

            mask = category_pred >= category_threshold
            tag_names = tag_names[mask].tolist()
            category_pred = category_pred[mask].tolist()
            values[self._category_names[category]] = \
                dict(sorted(zip(tag_names, category_pred), key=lambda x: (-x[1], x[0])))

        return vreplace(fmt, values)

    def make_ui(self, default_thresholds: Union[float, Dict[Any, float]] = None):
        """
        Build the Gradio components of the tagging demo inside the current Gradio context.

        The UI has an image input, an optional "use tag-level thresholds" checkbox (only
        shown when the tag table has a ``best_threshold`` column), one threshold slider per
        category, a submit button and one label tab per category.

        :param default_thresholds: Initial slider values: a single number for all sliders,
            or a dict keyed by category id or category name. Unspecified categories start
            at their ``thresholds.csv`` value, or ``0.4`` without one.
        :type default_thresholds: Union[float, Dict[Any, float]]
        :raises EnvironmentError: If Gradio is not installed.
        """
        _check_gradio_env()
        df_tags = self._open_tags()
        default_category_thresholds = self._open_default_category_thresholds()
        allow_use_tag_thresholds = 'best_threshold' in df_tags
        categories = self._categories(df_tags)

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    gr_use_tag_thresholds = gr.Checkbox(
                        value=allow_use_tag_thresholds,
                        label='Use Tag-Level Thresholds (Category Thresholds Will Be Ignored When Enabled!!!)',
                        interactive=allow_use_tag_thresholds,
                        visible=allow_use_tag_thresholds,
                    )
                with gr.Row():
                    gr_thresholds = []
                    for category in categories:
                        specified, category_threshold = self._explicit_threshold(category, default_thresholds)
                        if not specified:
                            category_threshold = default_category_thresholds.get(
                                category, self._FALLBACK_THRESHOLD)

                        gr_cate_threshold = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=category_threshold,
                            step=0.001,
                            label=f'Threshold for {titleize(self._category_names[category])}',
                        )
                        gr_thresholds.append(gr_cate_threshold)

                with gr.Row():
                    gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                with gr.Tabs():
                    gr_preds = []
                    for category in categories:
                        with gr.Tab(f'{titleize(self._category_names[category])}'):
                            # NOTE(review): the caption is passed positionally; confirm it lands
                            # on the intended gr.Label parameter.
                            gr_cate_label = gr.Label(f'{titleize(self._category_names[category])} Prediction')
                            gr_preds.append(gr_cate_label)

            def _fn_submit(image, _use_tag_thresholds, *thresholds):
                # With tag-level thresholds enabled the sliders are ignored entirely;
                # otherwise each slider value is bound to its category id.
                if _use_tag_thresholds:
                    _ths = None
                else:
                    _ths = dict(zip(categories, thresholds))

                fmt = tuple(self._category_names[category] for category in categories)
                res = self.predict(
                    image=image,
                    thresholds=_ths,
                    use_tag_thresholds=_use_tag_thresholds,
                    fmt=fmt,
                )
                # A single output component receives the bare value rather than a 1-tuple.
                return res[0] if len(fmt) == 1 else res

            gr_submit.click(
                fn=_fn_submit,
                inputs=[gr_input_image, gr_use_tag_thresholds, *gr_thresholds],
                outputs=[*gr_preds]
            )

    def launch_demo(self, default_thresholds: Union[float, Dict[Any, float]] = None,
                    server_name: Optional[str] = None, server_port: Optional[int] = None, **kwargs):
        """
        Build a standalone Gradio app around :meth:`make_ui` and launch it.

        :param default_thresholds: Forwarded to :meth:`make_ui`.
        :type default_thresholds: Union[float, Dict[Any, float]]
        :param server_name: Forwarded to ``demo.launch``.
        :type server_name: Optional[str]
        :param server_port: Forwarded to ``demo.launch``.
        :type server_port: Optional[int]
        :param kwargs: Additional keyword arguments forwarded to ``demo.launch``.
        :raises EnvironmentError: If Gradio is not installed.
        """
        _check_gradio_env()
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    repo_url = hf_hub_repo_url(repo_id=self.repo_id, repo_type='model')
                    gr.HTML(f'<h2 style="text-align: center;">Tagger Demo For {self.repo_id}</h2>')
                    gr.Markdown(f'This is the quick demo for tagger model [{self.repo_id}]({repo_url}). '
                                f'Powered by `dghs-imgutils`\'s quick demo module.')

            with gr.Row():
                self.make_ui(default_thresholds=default_thresholds)

        demo.launch(
            server_name=server_name,
            server_port=server_port,
            **kwargs,
        )


@ts_lru_cache()
def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[Any, str], ...]] = None,
                             hf_token: Optional[str] = None) \
        -> MultiLabelTIMMModel:
    """
    Create the :class:`MultiLabelTIMMModel` for a repository (wrapped by ``ts_lru_cache``).

    :param repo_id: Hugging Face model repository id.
    :type repo_id: str
    :param category_names: Category display names as a tuple of ``(category, name)`` pairs;
        ``None`` or an empty tuple means no custom names.
    :type category_names: Optional[Tuple[Tuple[Any, str], ...]]
    :param hf_token: Access token handed to the model object.
    :type hf_token: Optional[str]
    :return: Model object for the given arguments.
    :rtype: MultiLabelTIMMModel
    """
    # The pairs are turned back into the dict the model constructor expects.
    names = dict(category_names) if category_names else {}
    return MultiLabelTIMMModel(repo_id=repo_id, hf_token=hf_token, category_names=names)


def multilabel_predict(image: ImageTyping, repo_id: str, category_names: Dict[Any, str] = None,
                       preprocessor: Literal['test', 'val'] = 'test',
                       thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = False,
                       fmt=FMT_UNSET, hf_token: Optional[str] = None):
    """
    Predict the tags of an image with the multi-label model of a Hugging Face repository.

    The model object is obtained from ``_open_models_for_repo_id`` and the remaining
    arguments are forwarded unchanged to its ``predict`` method.

    :param image: Image to tag.
    :type image: ImageTyping
    :param repo_id: Hugging Face model repository id.
    :type repo_id: str
    :param category_names: Optional mapping from category to display name.
    :type category_names: Dict[Any, str]
    :param preprocessor: Which preprocessing pipeline to apply, ``'test'`` or ``'val'``.
    :type preprocessor: Literal['test', 'val']
    :param thresholds: A single threshold, or a dict of per-category thresholds.
    :type thresholds: Union[float, Dict[Any, float]]
    :param use_tag_thresholds: Whether to use tag-level thresholds.
    :type use_tag_thresholds: bool
    :param fmt: Output format; left unset, the model picks its default.
    :param hf_token: Access token for the repository.
    :type hf_token: Optional[str]
    :return: Whatever ``MultiLabelTIMMModel.predict`` returns for these arguments.
    """
    # The loader takes the names as a sorted tuple of (category, name) pairs -
    # presumably so they are hashable for its cache; confirm against ts_lru_cache.
    name_pairs = tuple(sorted((category_names or {}).items()))
    model = _open_models_for_repo_id(repo_id=repo_id, category_names=name_pairs, hf_token=hf_token)

    return model.predict(
        image=image,
        preprocessor=preprocessor,
        thresholds=thresholds,
        use_tag_thresholds=use_tag_thresholds,
        fmt=fmt,
    )