Commit 313a6536 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix known bugs

parent 6c85e626
Loading
Loading
Loading
Loading
+4 −5
Original line number Diff line number Diff line
@@ -63,6 +63,8 @@ def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float]
    :return: Dictionary mapping category IDs to threshold values
    :rtype: dict
    """
    _default_category_thresholds: Dict[int, float] = {}
    _category_names: Dict[int, str] = {}
    try:
        df_category_thresholds = pd.read_csv(hf_hub_download(
            repo_id=_get_repo_id(model_name),
@@ -70,11 +72,8 @@ def _open_default_category_thresholds(model_name: str) -> Union[Dict[int, float]
            filename='thresholds.csv'
        ), keep_default_na=False)
    except (EntryNotFoundError,):
        _default_category_thresholds = {}
        _category_names = {}
        pass
    else:
        _default_category_thresholds = {}
        _category_names = {}
        for item in df_category_thresholds.to_dict('records'):
            if item['category'] not in _default_category_thresholds:
                _default_category_thresholds[item['category']] = item['threshold']
@@ -147,7 +146,7 @@ def get_pixai_tags(image: ImageTyping, model_name: str = 'v0.9',
    prediction = values['prediction']
    tags = {}

    default_category_thresholds, category_names = _open_default_category_thresholds()
    default_category_thresholds, category_names = _open_default_category_thresholds(model_name=model_name)
    if fmt is FMT_UNSET:
        fmt = tuple(category_names[category] for category in sorted(set(df_tags['category'].tolist())))