Loading imgutils/tagging/pixai.py +4 −5 Original line number Diff line number Diff line Loading @@ -63,6 +63,8 @@ def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float] :return: Dictionary mapping category IDs to threshold values :rtype: dict """ _default_category_thresholds: Dict[int, float] = {} _category_names: Dict[int, str] = {} try: df_category_thresholds = pd.read_csv(hf_hub_download( repo_id=_get_repo_id(model_name), Loading @@ -70,11 +72,8 @@ def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float] filename='thresholds.csv' ), keep_default_na=False) except (EntryNotFoundError,): _default_category_thresholds = {} _category_names = {} pass else: _default_category_thresholds = {} _category_names = {} for item in df_category_thresholds.to_dict('records'): if item['category'] not in _default_category_thresholds: _default_category_thresholds[item['category']] = item['threshold'] Loading Loading @@ -147,7 +146,7 @@ def get_pixai_tags(image: ImageTyping, model_name: str = 'v0.9', prediction = values['prediction'] tags = {} default_category_thresholds, category_names = _open_default_category_thresholds() default_category_thresholds, category_names = _open_default_category_thresholds(model_name=model_name) if fmt is FMT_UNSET: fmt = tuple(category_names[category] for category in sorted(set(df_tags['category'].tolist()))) Loading Loading
imgutils/tagging/pixai.py +4 −5 Original line number Diff line number Diff line Loading @@ -63,6 +63,8 @@ def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float] :return: Dictionary mapping category IDs to threshold values :rtype: dict """ _default_category_thresholds: Dict[int, float] = {} _category_names: Dict[int, str] = {} try: df_category_thresholds = pd.read_csv(hf_hub_download( repo_id=_get_repo_id(model_name), Loading @@ -70,11 +72,8 @@ def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float] filename='thresholds.csv' ), keep_default_na=False) except (EntryNotFoundError,): _default_category_thresholds = {} _category_names = {} pass else: _default_category_thresholds = {} _category_names = {} for item in df_category_thresholds.to_dict('records'): if item['category'] not in _default_category_thresholds: _default_category_thresholds[item['category']] = item['threshold'] Loading Loading @@ -147,7 +146,7 @@ def get_pixai_tags(image: ImageTyping, model_name: str = 'v0.9', prediction = values['prediction'] tags = {} default_category_thresholds, category_names = _open_default_category_thresholds() default_category_thresholds, category_names = _open_default_category_thresholds(model_name=model_name) if fmt is FMT_UNSET: fmt = tuple(category_names[category] for category in sorted(set(df_tags['category'].tolist()))) Loading