Commit 045a4d5b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add pydoc for methods

parent b742ff97
Loading
Loading
Loading
Loading
+92 −2
Original line number Diff line number Diff line
"""
Overview:
    A tool for assessing the aesthetic quality of anime images using a pre-trained model.

"""
from typing import Dict, Optional, Tuple

import numpy as np
@@ -25,6 +30,16 @@ _DEFAULT_LABEL_MAPPING = {


def _value_replace(v, mapping):
    """
    Replaces values in a data structure using a mapping dictionary.

    :param v: The input data structure.
    :type v: Any
    :param mapping: A dictionary mapping values to replacement values.
    :type mapping: Dict
    :return: The modified data structure.
    :rtype: Any
    """
    if isinstance(v, (list, tuple)):
        return type(v)([_value_replace(vitem, mapping) for vitem in v])
    elif isinstance(v, dict):
@@ -39,16 +54,44 @@ def _value_replace(v, mapping):


class AestheticModel:
    """
    A model for assessing the aesthetic quality of anime images.
    """

    def __init__(self, repo_id: str):
        """
        Initializes an AestheticModel instance.

        :param repo_id: The repository ID of the aesthetic assessment model.
        :type repo_id: str
        """
        self.repo_id = repo_id
        self.classifier = ClassifyModel(repo_id)
        self.cached_samples: Dict[str, Tuple] = {}

    def get_aesthetic_score(self, image: ImageTyping, model_name: str) -> Tuple[float, Dict[str, float]]:
        """
        Calculates the aesthetic score and confidence for an anime image.

        :param image: The input anime image.
        :type image: ImageTyping
        :param model_name: The name of the aesthetic assessment model to use.
        :type model_name: str
        :return: A tuple containing the aesthetic score and confidence.
        :rtype: Tuple[float, Dict[str, float]]
        """
        scores = self.classifier.predict_score(image, model_name)
        return sum(scores[label] * i for i, label in enumerate(_LABELS)), scores

    def _get_xy_samples(self, model_name: str):
        """
        Retrieves cached samples for aesthetic assessment.

        :param model_name: The name of the aesthetic assessment model.
        :type model_name: str
        :return: Cached samples for aesthetic assessment.
        :rtype: Tuple[Tuple[np.ndarray, float, float], Tuple[np.ndarray, float, float]]
        """
        if model_name not in self.cached_samples:
            stacked = np.load(hf_hub_download(
                repo_id=self.repo_id,
@@ -59,7 +102,17 @@ class AestheticModel:
            self.cached_samples[model_name] = ((x, x.min(), x.max()), (y, y.min(), y.max()))
        return self.cached_samples[model_name]

    def score_to_percentile(self, score: float, model_name: str):
    def score_to_percentile(self, score: float, model_name: str) -> float:
        """
        Converts an aesthetic score to a percentile rank.

        :param score: The aesthetic score.
        :type score: float
        :param model_name: The name of the aesthetic assessment model to use.
        :type model_name: str
        :return: The percentile rank corresponding to the given score.
        :rtype: float
        """
        (x, x_min, x_max), (y, y_min, y_max) = self._get_xy_samples(model_name)
        idx = np.searchsorted(x, np.clip(score, a_min=x_min, a_max=x_max))
        if idx < x.shape[0] - 1:
@@ -73,7 +126,17 @@ class AestheticModel:
            return y[idx]

    @classmethod
    def percentile_to_label(cls, percentile: float, mapping: Optional[Dict[str, float]] = None):
    def percentile_to_label(cls, percentile: float, mapping: Optional[Dict[str, float]] = None) -> str:
        """
        Converts a percentile rank to an aesthetic label.

        :param percentile: The percentile rank.
        :type percentile: float
        :param mapping: A dictionary mapping labels to percentile thresholds.
        :type mapping: Optional[Dict[str, float]]
        :return: The aesthetic label corresponding to the given percentile rank.
        :rtype: str
        """
        mapping = mapping or _DEFAULT_LABEL_MAPPING
        for label, threshold in sorted(mapping.items(), key=lambda x: (-x[1], x[0])):
            if percentile >= threshold:
@@ -82,6 +145,18 @@ class AestheticModel:
            raise ValueError(f'No label for unknown percentile {percentile:.3f}.')

    def get_aesthetic(self, image: ImageTyping, model_name: str, fmt=('label', 'percentile')):
        """
        Analyzes the aesthetic quality of an anime image and returns the results in the specified format.

        :param image: The input anime image.
        :type image: ImageTyping
        :param model_name: The name of the aesthetic assessment model to use.
        :type model_name: str
        :param fmt: The format of the output.
        :type fmt: Tuple[str, ...]
        :return: A dictionary containing the aesthetic assessment results.
        :rtype: Dict[str, float]
        """
        score, confidence = self.get_aesthetic_score(image, model_name)
        percentile = self.score_to_percentile(score, model_name)
        label = self.percentile_to_label(percentile)
@@ -96,6 +171,9 @@ class AestheticModel:
        )

    def clear(self):
        """
        Clears the internal state of the AestheticModel instance.
        """
        self.classifier.clear()
        self.cached_samples.clear()

@@ -105,4 +183,16 @@ _MODEL = AestheticModel(_REPO_ID)

def anime_dbaesthetic(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME,
                      fmt=('label', 'percentile')):
    """
    Analyzes the aesthetic quality of an anime image using a pre-trained model.

    :param image: The input anime image.
    :type image: ImageTyping
    :param model_name: The name of the aesthetic assessment model to use. Default is _DEFAULT_MODEL_NAME.
    :type model_name: str
    :param fmt: The format of the output. Default is ('label', 'percentile').
    :type fmt: Tuple[str, ...]
    :return: A dictionary containing the aesthetic assessment results.
    :rtype: Dict[str, float]
    """
    return _MODEL.get_aesthetic(image, model_name, fmt)