Loading imgutils/generic/multilabel_timm.py +19 −22 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ The models are expected to be stored on Hugging Face Hub with specific files: - selected_tags.csv: CSV file containing tag information and categories - preprocess.json: JSON configuration for image preprocessing - thresholds.csv: Optional CSV file with recommended thresholds - categories.json: Category ID and name mapping json file. This module is designed to work with multi-label classification tasks where images can belong to multiple categories and have multiple tags within each category. Loading @@ -25,7 +26,7 @@ import json import os import warnings from threading import Lock from typing import Optional, Literal, Dict, Any, Union, Tuple from typing import Optional, Literal, Dict, Any, Union import pandas as pd from hbutils.string import titleize Loading Loading @@ -77,11 +78,9 @@ class MultiLabelTIMMModel: :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ def __init__(self, repo_id: str, hf_token: Optional[str] = None, category_names: Dict[Any, str] = None): def __init__(self, repo_id: str, hf_token: Optional[str] = None): """ Initialize a MultiLabelTIMMModel. Loading @@ -89,8 +88,6 @@ class MultiLabelTIMMModel: :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ self.repo_id = repo_id self._model = None Loading @@ -99,7 +96,7 @@ class MultiLabelTIMMModel: self._default_category_thresholds = None self._hf_token = hf_token self._lock = Lock() self._category_names = category_names or {} self._category_names = {} self._name_to_categories = None def _get_hf_token(self) -> Optional[str]: Loading Loading @@ -146,10 +143,17 @@ class MultiLabelTIMMModel: filename='selected_tags.csv', token=self._get_hf_token(), )) with open(hf_hub_download( repo_id=self.repo_id, repo_type='model', filename='categories.json', token=self._get_hf_token(), ), 'r') as f: d_category_names = {cate_item['category']: cate_item['name'] for cate_item in json.load(f)} self._name_to_categories = {} for category in sorted(set(self._df_tags['category'])): if not self._category_names.get(category): self._category_names[category] = f'category_{category}' self._category_names[category] = d_category_names[category] self._name_to_categories[self._category_names[category]] = category return self._df_tags Loading Loading @@ -407,7 +411,7 @@ class MultiLabelTIMMModel: ) with io.StringIO() as sf: for category, res_item in zip(sorted(set(df_tags['category'].tolist())), res): print(f'# {self._category_names[category]} (#{category}):', file=sf) print(f'# {self._category_names[category]} (#{category})', file=sf) print(', '.join(res_item.keys()), file=sf) print('', file=sf) Loading Loading @@ -462,16 +466,13 @@ class MultiLabelTIMMModel: @ts_lru_cache() def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[Any, str], ...]] = None, hf_token: Optional[str] = None) \ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) \ -> MultiLabelTIMMModel: """ Open and cache a MultiLabelTIMMModel for a given repository ID. :param repo_id: The Hugging Face Hub repository ID :type repo_id: str :param category_names: Optional tuple of (category_id, name) pairs for category naming :type category_names: Optional[Tuple[Tuple[Any, str], ...]] :param hf_token: Optional Hugging Face authentication token :type hf_token: Optional[str] Loading @@ -481,11 +482,10 @@ def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[ return MultiLabelTIMMModel( repo_id=repo_id, hf_token=hf_token, category_names=dict(category_names or []), ) def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Dict[Any, str] = None, def multilabel_timm_predict(image: ImageTyping, repo_id: str, preprocessor: Literal['test', 'val'] = 'test', thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET, hf_token: Optional[str] = None): Loading @@ -499,8 +499,6 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di :type image: ImageTyping :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories Loading Loading @@ -543,7 +541,6 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di """ model = _open_models_for_repo_id( repo_id=repo_id, category_names=tuple((key, value) for key, value in sorted((category_names or {}).items())), hf_token=hf_token, ) return model.predict( Loading test/generic/test_multilabel_timm.py +0 −6 Original line number Diff line number Diff line Loading @@ -19,7 +19,6 @@ class TestGenericMultilabelTIMM: general, character, rating = multilabel_timm_predict( get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading Loading @@ -64,7 +63,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={'general': 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading @@ -87,7 +85,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={0: 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading @@ -110,7 +107,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds=0.3, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading @@ -133,7 +129,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={0: 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), use_tag_thresholds=False, ) Loading @@ -157,7 +152,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={0: 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), use_tag_thresholds=False, preprocessor='val', Loading Loading
imgutils/generic/multilabel_timm.py +19 −22 Original line number Diff line number Diff line Loading @@ -15,6 +15,7 @@ The models are expected to be stored on Hugging Face Hub with specific files: - selected_tags.csv: CSV file containing tag information and categories - preprocess.json: JSON configuration for image preprocessing - thresholds.csv: Optional CSV file with recommended thresholds - categories.json: Category ID and name mapping json file. This module is designed to work with multi-label classification tasks where images can belong to multiple categories and have multiple tags within each category. Loading @@ -25,7 +26,7 @@ import json import os import warnings from threading import Lock from typing import Optional, Literal, Dict, Any, Union, Tuple from typing import Optional, Literal, Dict, Any, Union import pandas as pd from hbutils.string import titleize Loading Loading @@ -77,11 +78,9 @@ class MultiLabelTIMMModel: :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ def __init__(self, repo_id: str, hf_token: Optional[str] = None, category_names: Dict[Any, str] = None): def __init__(self, repo_id: str, hf_token: Optional[str] = None): """ Initialize a MultiLabelTIMMModel. Loading @@ -89,8 +88,6 @@ class MultiLabelTIMMModel: :type repo_id: str :param hf_token: Optional Hugging Face authentication token for private repositories :type hf_token: Optional[str] :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] """ self.repo_id = repo_id self._model = None Loading @@ -99,7 +96,7 @@ class MultiLabelTIMMModel: self._default_category_thresholds = None self._hf_token = hf_token self._lock = Lock() self._category_names = category_names or {} self._category_names = {} self._name_to_categories = None def _get_hf_token(self) -> Optional[str]: Loading Loading @@ -146,10 +143,17 @@ class MultiLabelTIMMModel: filename='selected_tags.csv', token=self._get_hf_token(), )) with open(hf_hub_download( repo_id=self.repo_id, repo_type='model', filename='categories.json', token=self._get_hf_token(), ), 'r') as f: d_category_names = {cate_item['category']: cate_item['name'] for cate_item in json.load(f)} self._name_to_categories = {} for category in sorted(set(self._df_tags['category'])): if not self._category_names.get(category): self._category_names[category] = f'category_{category}' self._category_names[category] = d_category_names[category] self._name_to_categories[self._category_names[category]] = category return self._df_tags Loading Loading @@ -407,7 +411,7 @@ class MultiLabelTIMMModel: ) with io.StringIO() as sf: for category, res_item in zip(sorted(set(df_tags['category'].tolist())), res): print(f'# {self._category_names[category]} (#{category}):', file=sf) print(f'# {self._category_names[category]} (#{category})', file=sf) print(', '.join(res_item.keys()), file=sf) print('', file=sf) Loading Loading @@ -462,16 +466,13 @@ class MultiLabelTIMMModel: @ts_lru_cache() def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[Any, str], ...]] = None, hf_token: Optional[str] = None) \ def _open_models_for_repo_id(repo_id: str, hf_token: Optional[str] = None) \ -> MultiLabelTIMMModel: """ Open and cache a MultiLabelTIMMModel for a given repository ID. :param repo_id: The Hugging Face Hub repository ID :type repo_id: str :param category_names: Optional tuple of (category_id, name) pairs for category naming :type category_names: Optional[Tuple[Tuple[Any, str], ...]] :param hf_token: Optional Hugging Face authentication token :type hf_token: Optional[str] Loading @@ -481,11 +482,10 @@ def _open_models_for_repo_id(repo_id: str, category_names: Optional[Tuple[Tuple[ return MultiLabelTIMMModel( repo_id=repo_id, hf_token=hf_token, category_names=dict(category_names or []), ) def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Dict[Any, str] = None, def multilabel_timm_predict(image: ImageTyping, repo_id: str, preprocessor: Literal['test', 'val'] = 'test', thresholds: Union[float, Dict[Any, float]] = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET, hf_token: Optional[str] = None): Loading @@ -499,8 +499,6 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di :type image: ImageTyping :param repo_id: The Hugging Face Hub repository ID containing the model :type repo_id: str :param category_names: Optional mapping of category IDs to display names :type category_names: Dict[Any, str] :param preprocessor: Which preprocessor to use ('test' or 'val') :type preprocessor: Literal['test', 'val'] :param thresholds: Threshold values for tag confidence. Can be a single float applied to all categories Loading Loading @@ -543,7 +541,6 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str, category_names: Di """ model = _open_models_for_repo_id( repo_id=repo_id, category_names=tuple((key, value) for key, value in sorted((category_names or {}).items())), hf_token=hf_token, ) return model.predict( Loading
test/generic/test_multilabel_timm.py +0 −6 Original line number Diff line number Diff line Loading @@ -19,7 +19,6 @@ class TestGenericMultilabelTIMM: general, character, rating = multilabel_timm_predict( get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading Loading @@ -64,7 +63,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={'general': 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading @@ -87,7 +85,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={0: 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading @@ -110,7 +107,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds=0.3, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), ) assert general == pytest.approx({ Loading @@ -133,7 +129,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={0: 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), use_tag_thresholds=False, ) Loading @@ -157,7 +152,6 @@ class TestGenericMultilabelTIMM: get_testfile('nude_girl.png'), repo_id='animetimm/mobilenetv3_large_150d.dbv4-full', thresholds={0: 0.3}, category_names={0: 'general', 4: 'character', 9: 'rating'}, fmt=('general', 'character', 'rating'), use_tag_thresholds=False, preprocessor='val', Loading