Commit 295af1f4 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): replace aicheck

parent 6a6f07ce
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.validate import get_ai_created_score
from imgutils.validate.aicheck import _MODEL_NAMES
from imgutils.validate.aicheck import _REPO_ID

_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names


class AnimeAICheckBenchmark(BaseBenchmark):
@@ -11,12 +14,10 @@ class AnimeAICheckBenchmark(BaseBenchmark):
        self.model = model

    def load(self):
        from imgutils.validate.aicheck import _open_anime_aicheck_model
        _ = _open_anime_aicheck_model(self.model)
        _open_models_for_repo_id(_REPO_ID)._open_model(self.model)

    def unload(self):
        from imgutils.validate.aicheck import _open_anime_aicheck_model
        _open_anime_aicheck_model.cache_clear()
        _open_models_for_repo_id(_REPO_ID).clear()

    def run(self):
        image_file = random.choice(self.all_images)
+0 −2454

File deleted.

Preview size limit exceeded, changes collapsed.

+48 −85
Original line number Diff line number Diff line
@@ -15,62 +15,72 @@ Overview:
    The models are hosted on
    `huggingface - deepghs/anime_ai_check <https://huggingface.co/deepghs/anime_ai_check>`_.
"""
from functools import lru_cache
from typing import Tuple, Optional

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model
from ..data import ImageTyping
from ..generic import classify_predict, classify_predict_score

__all__ = [
    'get_ai_created_score',
    'is_ai_created',
]

_LABELS = ['ai', 'human']
_MODEL_NAMES = [
    'caformer_s36_plus_sce',
    'mobilenetv3_sce',
    'mobilenetv3_sce_dist',
]
_DEFAULT_MODEL_NAME = 'mobilenetv3_sce_dist'
_REPO_ID = 'deepghs/anime_ai_check'


@lru_cache()
def _open_anime_aicheck_model(model_name):
    return open_onnx_model(hf_hub_download(
        f'deepghs/anime_ai_check',
        f'{model_name}/model.onnx',
    ))


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')
def get_ai_created_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> float:
    """
    Overview:
        Predict if the given image is created by AI (mainly by stable diffusion), given a score.

    if normalize is not None:
        mean_, std_ = normalize
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std
    :param image: Image to be predicted.
    :param model_name: Name of the model. Default is ``mobilenetv3_sce_dist``.
        If you need better accuracy, use ``caformer_s36_plus_sce``.
        All the available values are listed on the benchmark graph.
    :return: A float number which represent the score of AI-check.

    return data.astype(np.float32)
    Examples::
        >>> from imgutils.validate import get_ai_created_score
        >>>
        >>> get_ai_created_score('aicheck/ai/1.jpg')
        0.9996960163116455
        >>> get_ai_created_score('aicheck/ai/2.jpg')
        0.9999125003814697
        >>> get_ai_created_score('aicheck/ai/3.jpg')
        0.997803270816803
        >>> get_ai_created_score('aicheck/ai/4.jpg')
        0.9960069060325623
        >>> get_ai_created_score('aicheck/ai/5.jpg')
        0.9887709021568298
        >>> get_ai_created_score('aicheck/ai/6.jpg')
        0.9998629093170166
        >>> get_ai_created_score('aicheck/human/7.jpg')
        0.0013722758740186691
        >>> get_ai_created_score('aicheck/human/8.jpg')
        0.00020673229300882667
        >>> get_ai_created_score('aicheck/human/9.jpg')
        0.0001895089662866667
        >>> get_ai_created_score('aicheck/human/10.jpg')
        0.0008857478387653828
        >>> get_ai_created_score('aicheck/human/11.jpg')
        4.552320024231449e-05
        >>> get_ai_created_score('aicheck/human/12.jpg')
        0.001168627175502479
    """
    return classify_predict_score(image, _REPO_ID, model_name)['ai']


def get_ai_created_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> float:
def is_ai_created(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = 0.5) -> bool:
    """
    Overview:
        Predict if the given image is created by AI (mainly by stable diffusion), given a score.
        Predict if the given image is created by AI (mainly by stable diffusion).

    :param image: Image to be predicted.
    :param model_name: Name of the model. Default is ``mobilenetv3_sce_dist``.
        If you need better accuracy, use ``caformer_s36_plus_sce``.
        All the available values are listed on the benchmark graph.
    :return: A float number which represent the score of AI-check.
    :param threshold: Threshold of the score. When the score is no less than ``threshold``, this image
        will be predicted as ``AI-created``. Default is ``0.5``.
    :return: This image is ``AI-created`` or not.

    Examples::
        >>> from imgutils.validate import is_ai_created
@@ -100,52 +110,5 @@ def get_ai_created_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NA
        >>> is_ai_created('aicheck/human/12.jpg')
        False
    """
    image = load_image(image, force_background='white', mode='RGB')
    input_ = _img_encode(image)[None, ...]
    output, = _open_anime_aicheck_model(model_name).run(['output'], {'input': input_})

    return output[0][0].item()


def is_ai_created(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = 0.5) -> bool:
    """
    Overview:
        Predict if the given image is created by AI (mainly by stable diffusion).

    :param image: Image to be predicted.
    :param model_name: Name of the model. Default is ``mobilenetv3_sce_dist``.
        If you need better accuracy, use ``caformer_s36_plus_sce``.
        All the available values are listed on the benchmark graph.
    :param threshold: Threshold of the score. When the score is no less than ``threshold``, this image
        will be predicted as ``AI-created``. Default is ``0.5``.
    :return: This image is ``AI-created`` or not.

    Examples::
        >>> from imgutils.validate import get_ai_created_score
        >>>
        >>> get_ai_created_score('aicheck/ai/1.jpg')
        0.9996960163116455
        >>> get_ai_created_score('aicheck/ai/2.jpg')
        0.9999125003814697
        >>> get_ai_created_score('aicheck/ai/3.jpg')
        0.997803270816803
        >>> get_ai_created_score('aicheck/ai/4.jpg')
        0.9960069060325623
        >>> get_ai_created_score('aicheck/ai/5.jpg')
        0.9887709021568298
        >>> get_ai_created_score('aicheck/ai/6.jpg')
        0.9998629093170166
        >>> get_ai_created_score('aicheck/human/7.jpg')
        0.0013722758740186691
        >>> get_ai_created_score('aicheck/human/8.jpg')
        0.00020673229300882667
        >>> get_ai_created_score('aicheck/human/9.jpg')
        0.0001895089662866667
        >>> get_ai_created_score('aicheck/human/10.jpg')
        0.0008857478387653828
        >>> get_ai_created_score('aicheck/human/11.jpg')
        4.552320024231449e-05
        >>> get_ai_created_score('aicheck/human/12.jpg')
        0.001168627175502479
    """
    return get_ai_created_score(image, model_name) >= threshold
    type_, _ = classify_predict(image, _REPO_ID, model_name)
    return type_ == 'ai'
+3 −2
Original line number Diff line number Diff line
@@ -3,7 +3,8 @@ import os.path

import pytest

from imgutils.validate.aicheck import _open_anime_aicheck_model, is_ai_created, get_ai_created_score
from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.validate.aicheck import is_ai_created, get_ai_created_score, _REPO_ID
from test.testings import get_testfile

_ROOT_DIR = get_testfile('anime_aicheck')
@@ -18,7 +19,7 @@ def _release_model_after_run():
    try:
        yield
    finally:
        _open_anime_aicheck_model.cache_clear()
        _open_models_for_repo_id(_REPO_ID).clear()


@pytest.mark.unittest