Commit fa197576 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix tags

parent f0f1469a
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -153,6 +153,7 @@ class MultiLabelTIMMModel:
        df_tags = self._open_tags()
        values = self._raw_predict(image, preprocessor=preprocessor)
        prediction = values['prediction']
        tags = {}

        if fmt is FMT_UNSET:
            fmt = tuple(self._category_names[category] for category in sorted(set(df_tags['category'].tolist())))
@@ -190,9 +191,11 @@ class MultiLabelTIMMModel:
            mask = category_pred >= category_threshold
            tag_names = tag_names[mask].tolist()
            category_pred = category_pred[mask].tolist()
            values[self._category_names[category]] = \
                dict(sorted(zip(tag_names, category_pred), key=lambda x: (-x[1], x[0])))
            cate_tags = dict(sorted(zip(tag_names, category_pred), key=lambda x: (-x[1], x[0])))
            values[self._category_names[category]] = cate_tags
            tags.update(cate_tags)

        values['tag'] = tags
        return vreplace(fmt, values)

    def make_ui(self, default_thresholds: Union[float, Dict[Any, float]] = None,