Commit 205effb7 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add test for 0, 4 safety models

parent 0e164f43
Loading
Loading
Loading
Loading
+13 −8
Original line number Diff line number Diff line
from functools import lru_cache
from typing import Optional, Tuple
from typing import Optional, Tuple, Mapping

import numpy as np
from PIL import Image
@@ -13,7 +13,11 @@ __all__ = [
    'is_monochrome',
]

_DEFAULT_MONOCHROME_CKPT = 'monochrome-caformer_safe2-80.onnx'
_MODELS: Mapping[int, str] = {
    0: 'monochrome-caformer-110.onnx',
    2: 'monochrome-caformer_safe2-80.onnx',
    4: 'monochrome-caformer_safe4-70.onnx',
}


@lru_cache()
@@ -26,8 +30,6 @@ def _monochrome_validate_model(ckpt):

def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
               normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')

@@ -40,13 +42,16 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
    return data


def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float:
def get_monochrome_score(image: ImageTyping, safe: int = 2) -> float:
    if safe not in _MODELS:
        raise ValueError(f'Safe level should be one of {set(sorted(_MODELS.keys()))!r}, but {safe!r} found.')

    image = load_image(image, mode='RGB')
    input_data = _2d_encode(image).astype(np.float32)
    input_data = np.stack([input_data])
    output_data, = _monochrome_validate_model(ckpt).run(['output'], {'input': input_data})
    output_data, = _monochrome_validate_model(_MODELS[safe]).run(['output'], {'input': input_data})
    return float(output_data[0][1])


def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool:
    return get_monochrome_score(image, ckpt) >= threshold
def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) -> bool:
    return get_monochrome_score(image, safe) >= threshold
+28 −7
Original line number Diff line number Diff line
@@ -4,9 +4,12 @@ import random

import pytest
from hbutils.random import keep_global_state, global_seed
from hbutils.testing import tmatrix

from imgutils.validate.monochrome import get_monochrome_score, is_monochrome

_KNOWN_DUPS = {'2475192.jpg', '3842254.jpg', '2108110.jpg', '5257139.jpg', '6032011.jpg', '75719.jpg'}


@keep_global_state()
def get_samples():
@@ -14,17 +17,35 @@ def get_samples():
    all_samples_from_dataset = glob.glob(
        os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', '*', '*.jpg'))
    files = random.sample(all_samples_from_dataset, k=30)
    return sorted([(os.path.basename(os.path.dirname(file)), os.path.basename(file)) for file in files])
    return sorted([
        (os.path.basename(os.path.dirname(file)), os.path.basename(file))
        for file in files if os.path.basename(file) not in _KNOWN_DUPS
    ])


@pytest.mark.unittest
class TestValidateMonochrome:
    @pytest.mark.parametrize(['type_', 'file'], get_samples())
    def test_monochrome_test(self, type_: str, file: str):
    @pytest.mark.parametrize(*tmatrix({
        ('type_', 'file'): get_samples(),
        'safe': [0, 2, 4],
    }))
    def test_monochrome_test(self, type_: str, file: str, safe: int):
        filename = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', type_, file)
        if type_ == 'monochrome':
            assert get_monochrome_score(filename) >= 0.5
            assert is_monochrome(filename)
            assert get_monochrome_score(filename, safe=safe) >= 0.5
            assert is_monochrome(filename, safe=safe)
        else:
            assert get_monochrome_score(filename) <= 0.5
            assert not is_monochrome(filename)
            assert get_monochrome_score(filename, safe=safe) <= 0.5
            assert not is_monochrome(filename, safe=safe)

    def test_monochrome_test_with_unknown_safe(self):
        with pytest.raises(ValueError):
            _ = get_monochrome_score(
                os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'),
                safe=100
            )
        with pytest.raises(ValueError):
            _ = is_monochrome(
                os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'),
                safe=100
            )