Commit 8ab17fe2 authored by narugo1992
Browse files

dev(narugo): add category names

parent 4145012c
Loading
Loading
Loading
Loading
+19 −22
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ The models are expected to be stored on Hugging Face Hub with specific files:
- selected_tags.csv: CSV file containing tag information and categories
- preprocess.json: JSON configuration for image preprocessing
- thresholds.csv: Optional CSV file with recommended thresholds
- categories.json: JSON file mapping category IDs to category names

This module is designed to work with multi-label classification tasks where images can
belong to multiple categories and have multiple tags within each category.
@@ -25,7 +26,7 @@ import json
import os
import warnings
from threading import Lock
from typing import Optional, Literal, Dict, Any, Union, Tuple
from typing import Optional, Literal, Dict, Any, Union

import pandas as pd
from hbutils.string import titleize
@@ -77,11 +78,9 @@ class MultiLabelTIMMModel:
    :type repo_id: str
    :param hf_token: Optional Hugging Face authentication token for private repositories
    :type hf_token: Optional[str]
    :param category_names: Optional mapping of category IDs to display names
    :type category_names: Dict[Any, str]
    """

    def __init__(self, repo_id: str, hf_token: Optional[str] = None, category_names: Dict[Any, str] = None):
    def __init__(self, repo_id: str, hf_token: Optional[str] = None):
        """
        Initialize a MultiLabelTIMMModel.

@@ -89,8 +88,6 @@ class MultiLabelTIMMModel:
        :type repo_id: str
        :param hf_token: Optional Hugging Face authentication token for private repositories
        :type hf_token: Optional[str]
        :param category_names: Optional mapping of category IDs to display names
        :type category_names: Dict[Any, str]
        """
        self.repo_id = repo_id
        self._model = None
@@ -99,7 +96,7 @@ class MultiLabelTIMMModel:
        self._default_category_thresholds = None
        self._hf_token = hf_token
        self._lock = Lock()
        self._category_names = category_names or {}
        self._category_names = {}
        self._name_to_categories = None

    def _get_hf_token(self) -> Optional[str]:
@@ -146,10 +143,17 @@ class MultiLabelTIMMModel:
                    filename='selected_tags.csv',
                    token=self._get_hf_token(),
                ))

                with open(hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type='model',
                        filename='categories.json',
                        token=self._get_hf_token(),
                ), 'r') as f:
                    d_category_names = {cate_item['category']: cate_item['name'] for cate_item in json.load(f)}
                    self._name_to_categories = {}
                    for category in sorted(set(self._df_tags['category'])):
                    if not self._category_names.get(category):
                        self._category_names[category] = f'category_{category}'
                        self._category_names[category] = d_category_names[category]
                        self._name_to_categories[self._category_names[category]] = category

        return self._df_tags
@@ -407,7 +411,7 @@ class MultiLabelTIMMModel:
                )
                with io.StringIO() as sf:
                    for category, res_item in zip(sorted(set(df_tags['category'].tolist())), res):
                        print(f'# {self._category_names[category]} (#{category}):', file=sf)
                        print(f'# {self._category_names[category]} (#{category})', file=sf)
                        print(', '.join(res_item.keys()), file=sf)
                        print('', file=sf)

@@ -462,16 +466,13 @@ class MultiLabelTIMMModel:


@ts_lru_cache()
def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[Any, str], ...]] = None,
                             hf_token: Optional[str] = None) \
def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) \
        -> MultiLabelTIMMModel:
    """
    Open and cache a MultiLabelTIMMModel for a given repository ID.

    :param repo_id: The Hugging Face Hub repository ID
    :type repo_id: str
    :param category_names: Optional tuple of (category_id, name) pairs for category naming
    :type category_names: Optional[Tuple[Tuple[Any, str], ...]]
    :param hf_token: Optional Hugging Face authentication token
    :type hf_token: Optional[str]

@@ -481,11 +482,10 @@ def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[
    return MultiLabelTIMMModel(
        repo_id=repo_id,
        hf_token=hf_token,
        category_names=dict(category_names or []),
    )


def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Dict[Any, str] = None,
def multilabel_timm_predict(image: ImageTyping, repo_id: str,
                            preprocessor: Literal['test', 'val'] = 'test',
                            thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True,
                            fmt=FMT_UNSET, hf_token: Optional[str] = None):
@@ -499,8 +499,6 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di
    :type image: ImageTyping
    :param repo_id: The Hugging Face Hub repository ID containing the model
    :type repo_id: str
    :param category_names: Optional mapping of category IDs to display names
    :type category_names: Dict[Any, str]
    :param preprocessor: Which preprocessor to use ('test' or 'val')
    :type preprocessor: Literal['test', 'val']
    :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories
@@ -543,7 +541,6 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di
    """
    model = _open_models_for_repo_id(
        repo_id=repo_id,
        category_names=tuple((key, value) for key, value in sorted((category_names or {}).items())),
        hf_token=hf_token,
    )
    return model.predict(
+0 −6
Original line number Diff line number Diff line
@@ -19,7 +19,6 @@ class TestGenericMultilabelTIMM:
        general, character, rating = multilabel_timm_predict(
            get_testfile('nude_girl.png'),
            repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
            category_names={0: 'general', 4: 'character', 9: 'rating'},
            fmt=('general', 'character', 'rating'),
        )
        assert general == pytest.approx({
@@ -64,7 +63,6 @@ class TestGenericMultilabelTIMM:
            get_testfile('nude_girl.png'),
            repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
            thresholds={'general': 0.3},
            category_names={0: 'general', 4: 'character', 9: 'rating'},
            fmt=('general', 'character', 'rating'),
        )
        assert general == pytest.approx({
@@ -87,7 +85,6 @@ class TestGenericMultilabelTIMM:
            get_testfile('nude_girl.png'),
            repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
            thresholds={0: 0.3},
            category_names={0: 'general', 4: 'character', 9: 'rating'},
            fmt=('general', 'character', 'rating'),
        )
        assert general == pytest.approx({
@@ -110,7 +107,6 @@ class TestGenericMultilabelTIMM:
            get_testfile('nude_girl.png'),
            repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
            thresholds=0.3,
            category_names={0: 'general', 4: 'character', 9: 'rating'},
            fmt=('general', 'character', 'rating'),
        )
        assert general == pytest.approx({
@@ -133,7 +129,6 @@ class TestGenericMultilabelTIMM:
            get_testfile('nude_girl.png'),
            repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
            thresholds={0: 0.3},
            category_names={0: 'general', 4: 'character', 9: 'rating'},
            fmt=('general', 'character', 'rating'),
            use_tag_thresholds=False,
        )
@@ -157,7 +152,6 @@ class TestGenericMultilabelTIMM:
            get_testfile('nude_girl.png'),
            repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
            thresholds={0: 0.3},
            category_names={0: 'general', 4: 'character', 9: 'rating'},
            fmt=('general', 'character', 'rating'),
            use_tag_thresholds=False,
            preprocessor='val',