Commit 9009e85e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add unittest for monochrome functions

parent bfd2c328
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -61,4 +61,4 @@ def _open_onnx_model(ckpt: str, provider: str) -> InferenceSession:


def open_onnx_model(ckpt: str, mode: str = None) -> InferenceSession:
    """
    Open an ONNX model file as an inference session.

    :param ckpt: Model file to open; forwarded unchanged to ``_open_onnx_model``.
    :param mode: Value handed to ``get_onnx_provider`` to pick the provider.
        When it is falsy (``None`` or an empty string), the ``ONNX_MODE``
        environment variable is used instead; if that is unset as well,
        ``None`` is passed on.
    :return: The session produced by ``_open_onnx_model`` for ``ckpt`` and
        the resolved provider.
    """
    # Imported locally so this function does not rely on a module-level import.
    import os

    # An explicit ``mode`` always wins; the environment is only a fallback.
    provider = get_onnx_provider(mode or os.environ.get('ONNX_MODE'))
    return _open_onnx_model(ckpt, provider)
+3 −47
Original line number Diff line number Diff line
@@ -2,10 +2,8 @@ from functools import lru_cache
from typing import Optional, Tuple

import numpy as np
from PIL import Image, ImageFilter
from PIL.Image import Resampling
from PIL import Image
from huggingface_hub import hf_hub_download
from scipy import signal

from ..data import ImageTyping, load_image, rgb_encode
from ..utils import open_onnx_model
@@ -15,8 +13,7 @@ __all__ = [
    'is_monochrome',
]

# Earlier default checkpoints; both are superseded by the last assignment below.
# _DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
_DEFAULT_MONOCHROME_CKPT = 'monochrome-levit_d0.2-500.onnx'
# Effective default: this later assignment replaces the value set just above.
# It is the default ``ckpt`` argument of ``get_monochrome_score``.
_DEFAULT_MONOCHROME_CKPT = 'monochrome-caformer_safe2-80.onnx'


@lru_cache()
@@ -27,51 +24,11 @@ def _monochrome_validate_model(ckpt):
    ))


def np_hist(x, a_min: float = 0.0, a_max: float = 1.0, bins: int = 256):
    """
    Compute a histogram of ``x`` normalised so that its entries sum to 1.

    :param x: Array-like of values; converted with ``np.asarray``.
    :param a_min: Lower edge of the first bin.
    :param a_max: Upper edge of the last bin.
    :param bins: Number of equal-width bins between ``a_min`` and ``a_max``.
    :return: Float array of length ``bins`` holding the fraction of counted
        values per bin. Values outside ``[a_min, a_max]`` are not counted.
        If nothing is counted at all (empty input, or every value out of
        range) an all-zero array is returned instead of NaNs.
    """
    values = np.asarray(x)
    edges = np.linspace(a_min, a_max, bins + 1)
    counts, _ = np.histogram(values, bins=edges)

    total = counts.sum()
    if total == 0:
        # Avoid 0 / 0: there is no distribution to normalise, so report
        # an empty histogram rather than an array of NaNs.
        return np.zeros(counts.shape, dtype=np.float64)
    return counts / total


def butterworth_filter(r, fc):
    """
    Low-pass filter a 1-D sequence and clamp the result to ``[0, 1]``.

    :param r: Sequence to smooth.
    :param fc: Cut-off frequency, expressed relative to ``len(r) / 2``
        (i.e. it is divided by half the length of ``r`` before being
        passed to the filter design).
    :return: Filtered sequence with every value clipped into ``[0.0, 1.0]``.
    """
    half_length = len(r) / 2
    normalized_cutoff = fc / half_length

    # 5th-order low-pass Butterworth, applied with ``filtfilt``.
    numerator, denominator = signal.butter(5, normalized_cutoff, 'low')
    smoothed = signal.filtfilt(numerator, denominator, r)
    return np.clip(smoothed, 0.0, 1.0)


def _hsv_encode(image: Image.Image, feature_bins: int = 180, mf: Optional[int] = 5,
                maxpixels: int = 20000, fc: Optional[int] = 75, normalize: bool = True):
    """
    Encode an image as per-channel HSV histograms.

    :param image: Image to encode.
    :param feature_bins: Number of histogram bins per channel.
    :param mf: Size given to ``ImageFilter.MedianFilter``; ``None`` skips
        the median filter.
    :param maxpixels: Images with more pixels than this are resized so that
        their area is approximately ``maxpixels``, keeping the aspect ratio.
    :param fc: Cut-off passed to ``butterworth_filter`` for each histogram;
        ``None`` skips the smoothing.
    :param normalize: When true, each channel row is shifted by its mean and
        divided by its sample standard deviation (``ddof=1``).
    :return: Array of shape ``(3, feature_bins)``.
    """
    pixel_count = image.width * image.height
    if pixel_count > maxpixels:
        # Scale both sides by the same factor so the area drops to ~maxpixels.
        shrink = (pixel_count / maxpixels) ** 0.5
        image = image.resize(tuple(int(round(side / shrink)) for side in image.size))

    if mf is not None:
        image = image.filter(ImageFilter.MedianFilter(mf))

    hsv_pixels = np.asarray(image.convert('HSV'))
    # Channel-first layout, scaled from 0..255 into 0..1 to match np_hist's range.
    planes = (np.transpose(hsv_pixels, (2, 0, 1)) / 255.0).astype(np.float32)

    histograms = []
    for index in range(3):
        hist = np_hist(planes[index], bins=feature_bins)
        if fc is not None:
            hist = butterworth_filter(hist, fc)
        histograms.append(hist)

    dist = np.stack(histograms)
    assert dist.shape == (3, feature_bins)

    if not normalize:
        return dist

    row_mean = np.mean(dist, axis=1, keepdims=True)
    row_std = np.std(dist, axis=1, keepdims=True, ddof=1)
    return (dist - row_mean) / row_std


def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
               normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(size, Resampling.BILINEAR)
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')

    if normalize is not None:
@@ -85,7 +42,6 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),

def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float:
    image = load_image(image, mode='RGB')
    # input_data = _hsv_encode(image).astype(np.float32)
    input_data = _2d_encode(image).astype(np.float32)
    input_data = np.stack([input_data])
    output_data, = _monochrome_validate_model(ckpt).run(['output'], {'input': input_data})
+30 −0
Original line number Diff line number Diff line
import glob
import os.path
import random

import pytest
from hbutils.random import keep_global_state, global_seed

from imgutils.validate.monochrome import get_monochrome_score, is_monochrome


@keep_global_state()
def get_samples():
    """
    Pick 30 sample images from the ``monochrome_danbooru`` test dataset.

    The global random state is seeded with ``0`` before sampling
    (``keep_global_state`` presumably restores the previous state afterwards;
    confirm against hbutils).

    :return: Sorted list of ``(type_, file)`` tuples, where ``type_`` is the
        name of the directory containing the image and ``file`` is the image
        file name.
    :raises ValueError: Raised by ``random.sample`` when fewer than 30 images
        match the dataset pattern.
    """
    global_seed(0)
    pattern = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', '*', '*.jpg')
    # glob's ordering depends on the file system; sort first so the fixed seed
    # selects the same files on every machine.
    candidates = sorted(glob.glob(pattern))
    chosen = random.sample(candidates, k=30)
    return sorted(
        (os.path.basename(os.path.dirname(path)), os.path.basename(path))
        for path in chosen
    )


@pytest.mark.unittest
class TestValidateMonochrome:
    """Checks the monochrome score and verdict against labelled dataset samples."""

    @pytest.mark.parametrize(['type_', 'file'], get_samples())
    def test_monochrome_test(self, type_: str, file: str):
        """
        Verify one sample: images from the ``monochrome`` directory must score
        at least 0.5 and be reported as monochrome; all others must score at
        most 0.5 and be reported as not monochrome.
        """
        filename = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', type_, file)
        score = get_monochrome_score(filename)

        if type_ != 'monochrome':
            assert score <= 0.5
            assert not is_monochrome(filename)
        else:
            assert score >= 0.5
            assert is_monochrome(filename)