Loading imgutils/validate/monochrome.py +13 −8 Original line number Diff line number Diff line from functools import lru_cache from typing import Optional, Tuple from typing import Optional, Tuple, Mapping import numpy as np from PIL import Image Loading @@ -13,7 +13,11 @@ __all__ = [ 'is_monochrome', ] _DEFAULT_MONOCHROME_CKPT = 'monochrome-caformer_safe2-80.onnx' _MODELS: Mapping[int, str] = { 0: 'monochrome-caformer-110.onnx', 2: 'monochrome-caformer_safe2-80.onnx', 4: 'monochrome-caformer_safe4-70.onnx', } @lru_cache() Loading @@ -26,8 +30,6 @@ def _monochrome_validate_model(ckpt): def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): if image.mode != 'RGB': image = image.convert('RGB') image = image.resize(size, Image.BILINEAR) data = rgb_encode(image, order_='CHW') Loading @@ -40,13 +42,16 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), return data def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float: def get_monochrome_score(image: ImageTyping, safe: int = 2) -> float: if safe not in _MODELS: raise ValueError(f'Safe level should be one of {set(sorted(_MODELS.keys()))!r}, but {safe!r} found.') image = load_image(image, mode='RGB') input_data = _2d_encode(image).astype(np.float32) input_data = np.stack([input_data]) output_data, = _monochrome_validate_model(ckpt).run(['output'], {'input': input_data}) output_data, = _monochrome_validate_model(_MODELS[safe]).run(['output'], {'input': input_data}) return float(output_data[0][1]) def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool: return get_monochrome_score(image, ckpt) >= threshold def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) -> bool: return get_monochrome_score(image, safe) >= threshold test/validate/test_monochrome.py +28 −7 Original line number Diff line number Diff line Loading @@ -4,9 +4,12 @@ import random import pytest from hbutils.random import keep_global_state, global_seed from hbutils.testing import tmatrix from imgutils.validate.monochrome import get_monochrome_score, is_monochrome _KNOWN_DUPS = {'2475192.jpg', '3842254.jpg', '2108110.jpg', '5257139.jpg', '6032011.jpg', '75719.jpg'} @keep_global_state() def get_samples(): Loading @@ -14,17 +17,35 @@ def get_samples(): all_samples_from_dataset = glob.glob( os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', '*', '*.jpg')) files = random.sample(all_samples_from_dataset, k=30) return sorted([(os.path.basename(os.path.dirname(file)), os.path.basename(file)) for file in files]) return sorted([ (os.path.basename(os.path.dirname(file)), os.path.basename(file)) for file in files if os.path.basename(file) not in _KNOWN_DUPS ]) @pytest.mark.unittest class TestValidateMonochrome: @pytest.mark.parametrize(['type_', 'file'], get_samples()) def test_monochrome_test(self, type_: str, file: str): @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): get_samples(), 'safe': [0, 2, 4], })) def test_monochrome_test(self, type_: str, file: str, safe: int): filename = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', type_, file) if type_ == 'monochrome': assert get_monochrome_score(filename) >= 0.5 assert is_monochrome(filename) assert get_monochrome_score(filename, safe=safe) >= 0.5 assert is_monochrome(filename, safe=safe) else: assert get_monochrome_score(filename) <= 0.5 assert not is_monochrome(filename) assert get_monochrome_score(filename, safe=safe) <= 0.5 assert not is_monochrome(filename, safe=safe) def test_monochrome_test_with_unknown_safe(self): with pytest.raises(ValueError): _ = get_monochrome_score( os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'), safe=100 ) with pytest.raises(ValueError): _ = is_monochrome( os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'), safe=100 ) Loading
imgutils/validate/monochrome.py +13 −8 Original line number Diff line number Diff line from functools import lru_cache from typing import Optional, Tuple from typing import Optional, Tuple, Mapping import numpy as np from PIL import Image Loading @@ -13,7 +13,11 @@ __all__ = [ 'is_monochrome', ] _DEFAULT_MONOCHROME_CKPT = 'monochrome-caformer_safe2-80.onnx' _MODELS: Mapping[int, str] = { 0: 'monochrome-caformer-110.onnx', 2: 'monochrome-caformer_safe2-80.onnx', 4: 'monochrome-caformer_safe4-70.onnx', } @lru_cache() Loading @@ -26,8 +30,6 @@ def _monochrome_validate_model(ckpt): def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): if image.mode != 'RGB': image = image.convert('RGB') image = image.resize(size, Image.BILINEAR) data = rgb_encode(image, order_='CHW') Loading @@ -40,13 +42,16 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), return data def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float: def get_monochrome_score(image: ImageTyping, safe: int = 2) -> float: if safe not in _MODELS: raise ValueError(f'Safe level should be one of {set(sorted(_MODELS.keys()))!r}, but {safe!r} found.') image = load_image(image, mode='RGB') input_data = _2d_encode(image).astype(np.float32) input_data = np.stack([input_data]) output_data, = _monochrome_validate_model(ckpt).run(['output'], {'input': input_data}) output_data, = _monochrome_validate_model(_MODELS[safe]).run(['output'], {'input': input_data}) return float(output_data[0][1]) def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool: return get_monochrome_score(image, ckpt) >= threshold def is_monochrome(image: ImageTyping, threshold: float = 0.5, safe: int = 2) -> bool: return get_monochrome_score(image, safe) >= threshold
test/validate/test_monochrome.py +28 −7 Original line number Diff line number Diff line Loading @@ -4,9 +4,12 @@ import random import pytest from hbutils.random import keep_global_state, global_seed from hbutils.testing import tmatrix from imgutils.validate.monochrome import get_monochrome_score, is_monochrome _KNOWN_DUPS = {'2475192.jpg', '3842254.jpg', '2108110.jpg', '5257139.jpg', '6032011.jpg', '75719.jpg'} @keep_global_state() def get_samples(): Loading @@ -14,17 +17,35 @@ def get_samples(): all_samples_from_dataset = glob.glob( os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', '*', '*.jpg')) files = random.sample(all_samples_from_dataset, k=30) return sorted([(os.path.basename(os.path.dirname(file)), os.path.basename(file)) for file in files]) return sorted([ (os.path.basename(os.path.dirname(file)), os.path.basename(file)) for file in files if os.path.basename(file) not in _KNOWN_DUPS ]) @pytest.mark.unittest class TestValidateMonochrome: @pytest.mark.parametrize(['type_', 'file'], get_samples()) def test_monochrome_test(self, type_: str, file: str): @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): get_samples(), 'safe': [0, 2, 4], })) def test_monochrome_test(self, type_: str, file: str, safe: int): filename = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', type_, file) if type_ == 'monochrome': assert get_monochrome_score(filename) >= 0.5 assert is_monochrome(filename) assert get_monochrome_score(filename, safe=safe) >= 0.5 assert is_monochrome(filename, safe=safe) else: assert get_monochrome_score(filename) <= 0.5 assert not is_monochrome(filename) assert get_monochrome_score(filename, safe=safe) <= 0.5 assert not is_monochrome(filename, safe=safe) def test_monochrome_test_with_unknown_safe(self): with pytest.raises(ValueError): _ = get_monochrome_score( os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'), safe=100 ) with pytest.raises(ValueError): _ = is_monochrome( os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'), safe=100 )