Commit 5f0835a1 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add a nice docstring for multilabel_timm_predict function

parent b4c62683
Loading
Loading
Loading
Loading
+59 −14
Original line number Diff line number Diff line
@@ -268,6 +268,8 @@ class MultiLabelTIMMModel:
            - ``prediction``: a 1-dim prediction result of image

            You can extract specific category predictions or all tags based on your needs.

        For more details see documentation of :func:`multilabel_timm_predict`.
        """
        df_tags = self._open_tags()
        values = self._raw_predict(image, preprocessor=preprocessor)
@@ -515,7 +517,8 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str,
    :return: Formatted prediction results according to the fmt parameter
    :rtype: Any

    Example::
    Example:
        Here are some images for example

        .. image:: multilabel_demo.plot.py.svg
           :align: center
@@ -547,8 +550,29 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str,
    .. note::
        For different models, the default format is different. That depends on the categories that model supported.

        For example, for model `animetimm/mobilenetv3_large_150d.dbv4-full-witha <https://huggingface.co/animetimm/mobilenetv3_large_150d.dbv4-full-witha>`_

        >>> from imgutils.generic import multilabel_timm_predict
        >>>
        >>> general, artist, character, rating = multilabel_timm_predict(
        ...     'skadi.jpg',
        ...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full-witha',
        ... )
        >>> general
        {'1girl': 0.9938606023788452, 'long_hair': 0.9691187143325806, 'red_eyes': 0.9463587403297424, 'solo': 0.944723904132843, 'navel': 0.9439248442649841, 'breasts': 0.9335891008377075, 'sportswear': 0.8865424394607544, 'shorts': 0.8601726293563843, 'very_long_hair': 0.8445472717285156, 'outdoors': 0.83197021484375, 'midriff': 0.8274217247962952, 'shirt': 0.8188955783843994, 'short_sleeves': 0.8183804750442505, 'crop_top': 0.8089936971664429, 'gloves': 0.8038264513015747, 'black_gloves': 0.7703496813774109, 'thighs': 0.7689077854156494, 'holding': 0.768336832523346, 'looking_at_viewer': 0.739115834236145, 'large_breasts': 0.7282243967056274, 'sky': 0.6852632761001587, 'hair_between_eyes': 0.6799711585044861, 'stomach': 0.6694454550743103, 'baseball_bat': 0.6693665385246277, 'black_shorts': 0.6493985652923584, 'day': 0.6425715684890747, 'cowboy_shot': 0.6186742186546326, 'black_shirt': 0.5906491279602051, 'holding_baseball_bat': 0.5860112905502319, 'sweat': 0.5825777649879456, 'cloud': 0.5549533367156982, 'blue_sky': 0.5523971915245056, 'white_hair': 0.5324308276176453, 'grey_hair': 0.52657151222229, 'short_shorts': 0.4896492063999176, 'standing': 0.45526784658432007, 'parted_lips': 0.4306206703186035, 'blush': 0.4149143397808075, 'thigh_gap': 0.4124316871166229, 'ass_visible_through_thighs': 0.34030789136886597, 'artist_name': 0.2679593563079834, 'groin': 0.2652612328529358, 'blurry': 0.2548949122428894, 'baseball': 0.24870169162750244, 'crop_top_overhang': 0.2240566909313202, 'stretching': 0.2012709677219391, 'cropped_shirt': 0.19828352332115173, 'official_alternate_costume': 0.1960265338420868, 'toned': 0.13941210508346558, 'exercising': 0.11270403861999512, 'lens_flare': 0.10835999250411987, 'taut_clothes': 0.08783495426177979, 'taut_shirt': 0.08448180556297302, 'linea_alba': 0.06583884358406067}
        >>> artist
        {}
        >>> character
        {'skadi_(arknights)': 0.8951651453971863}
        >>> rating
        {'sensitive': 0.9492285847663879}

        Its default fmt is ``('general', 'artist', 'character', 'rating')``.

        But you can easily get those information you need with a more controllable way with ``fmt``. See the next note.

    .. note::
        The fmt argument can include the following keys:
        The ``fmt`` argument can include the following keys:

        - Category names: dicts containing category-specific tags and their confidences
        - ``tag``: a dict containing all tags across categories and their confidences
@@ -558,18 +582,39 @@ def multilabel_timm_predict(image: ImageTyping, repo_id: str,

        You can extract specific category predictions or all tags based on your needs.

        Example:

        >>> # Get all categories and tags
        >>> result = multilabel_timm_predict('image.jpg', 'username/model-name')
        >>> from imgutils.generic import multilabel_timm_predict
        >>>
        >>> # default usage
        >>> general, character, rating = multilabel_timm_predict(
        ...     'skadi.jpg',
        ...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
        ... )
        >>> general
        {'1girl': 0.9963783025741577, 'long_hair': 0.9685494899749756, 'solo': 0.9548443555831909, 'navel': 0.9415484666824341, 'breasts': 0.9369214177131653, 'red_eyes': 0.9019639492034912, 'shirt': 0.873087465763092, 'outdoors': 0.866461992263794, 'crop_top': 0.862577497959137, 'midriff': 0.8544420003890991, 'sportswear': 0.849435567855835, 'shorts': 0.8209151029586792, 'short_sleeves': 0.817188560962677, 'holding': 0.811793327331543, 'very_long_hair': 0.8082301616668701, 'gloves': 0.7840366363525391, 'black_gloves': 0.7765430808067322, 'thighs': 0.7542579770088196, 'looking_at_viewer': 0.7331588268280029, 'day': 0.7203925251960754, 'hair_between_eyes': 0.7121687531471252, 'large_breasts': 0.6990523338317871, 'baseball_bat': 0.6809443831443787, 'grey_hair': 0.6790007948875427, 'sky': 0.6716539263725281, 'stomach': 0.6698249578475952, 'sweat': 0.6454322934150696, 'black_shirt': 0.6270318031311035, 'cowboy_shot': 0.6216483116149902, 'blue_sky': 0.5898874998092651, 'black_shorts': 0.5445142984390259, 'holding_baseball_bat': 0.5013713836669922, 'white_hair': 0.4999670684337616, 'blush': 0.4860053062438965, 'cloud': 0.474183052778244, 'standing': 0.4724341332912445, 'thigh_gap': 0.4330931305885315, 'short_shorts': 0.39793258905410767, 'parted_lips': 0.36694538593292236, 'crop_top_overhang': 0.3321989178657532, 'official_alternate_costume': 0.3157039284706116, 'blurry': 0.24181532859802246, 'groin': 0.21906554698944092, 'ass_visible_through_thighs': 0.2188207507133484, 'cropped_shirt': 0.18700966238975525, 'taut_shirt': 0.08612403273582458, 'taut_clothes': 0.0701744556427002}
        >>> character
        {'skadi_(arknights)': 0.9796262979507446}
        >>> rating
        {'sensitive': 0.9580697417259216}
        >>>
        >>> # Get only specific categories
        >>> result = multilabel_timm_predict('image.jpg', 'username/model-name',
        ...                                  fmt=('general', 'character'))
        >>> # get rating and character only
        >>> rating, character = multilabel_timm_predict(
        ...     'skadi.jpg',
        ...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
        ...     fmt=('rating', 'character'),
        ... )
        >>> rating
        {'sensitive': 0.9580697417259216}
        >>> character
        {'skadi_(arknights)': 0.9796262979507446}
        >>>
        >>> # Get just the raw prediction values
        >>> prediction = multilabel_timm_predict('image.jpg', 'username/model-name',
        ...                                     fmt='prediction')
        >>> # get embeddings only
        >>> embedding = multilabel_timm_predict(
        ...     'skadi.jpg',
        ...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
        ...     fmt='embedding',
        ... )
        >>> embedding.dtype, embedding.shape
        (dtype('float32'), (1280,))
    """
    model = _open_models_for_repo_id(
        repo_id=repo_id,