Commit e2294aab authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): merge from main

parents 9184c694 6863fa52
Loading
Loading
Loading
Loading
+77 −0
Original line number Diff line number Diff line
from functools import lru_cache
from typing import Optional

import numpy as np
from PIL import Image, ImageFilter
from huggingface_hub import hf_hub_download
from scipy import signal

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model

# Public API of this module: only these two names are exported by
# ``from <module> import *``; the remaining functions are helpers.
__all__ = [
    'get_monochrome_score',
    'is_monochrome',
]

# Checkpoint filename used when callers do not pass ``ckpt``; it is resolved
# under the ``monochrome/`` folder of the model repository at load time.
_DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'


@lru_cache()
def _monochrome_validate_model(ckpt):
    """
    Download a monochrome checkpoint and open it as an ONNX model.

    The result is memoized by ``lru_cache``, so each distinct ``ckpt`` value
    is downloaded and opened at most once while it stays in the cache.

    :param ckpt: Checkpoint filename, looked up as ``monochrome/<ckpt>`` in
        the ``deepghs/imgutils-models`` repository.
    :return: The object returned by ``open_onnx_model`` for the downloaded file.
    """
    model_file = hf_hub_download('deepghs/imgutils-models', f'monochrome/{ckpt}')
    return open_onnx_model(model_file)


def np_hist(x, a_min: float = 0.0, a_max: float = 1.0, bins: int = 256):
    """
    Compute a normalized histogram of ``x`` over ``[a_min, a_max]``.

    The range is split into ``bins`` equal-width bins. Values outside the
    range are ignored by ``np.histogram`` and therefore do not contribute.

    :param x: Array-like of numeric samples (any shape; it is flattened by
        ``np.histogram``).
    :param a_min: Lower edge of the first bin.
    :param a_max: Upper edge of the last bin.
    :param bins: Number of bins.
    :return: 1-D float array of length ``bins`` holding the fraction of the
        counted samples in each bin. When no sample falls inside the range
        (including empty input) an all-zero array is returned.
    """
    values = np.asarray(x)
    edges = np.linspace(a_min, a_max, bins + 1)
    counts, _ = np.histogram(values, bins=edges)

    total = counts.sum()
    if total == 0:
        # Nothing was counted: dividing would give 0/0 = NaN in every bin
        # (plus a RuntimeWarning), which would poison downstream features.
        return np.zeros(counts.shape, dtype=np.float64)
    return counts / total


def butterworth_filter(r, fc):
    """
    Smooth a 1-D sequence with a 5th-order low-pass Butterworth filter.

    The cutoff passed to ``signal.butter`` is ``fc`` divided by half the
    length of ``r``. The filter is applied with ``signal.filtfilt`` and the
    result is clipped into ``[0.0, 1.0]``.

    :param r: 1-D sequence to smooth.
    :param fc: Cutoff, expressed relative to ``len(r) / 2``.
    :return: Array of the filtered values, clipped to ``[0.0, 1.0]``.
    """
    half_length = len(r) / 2
    normalized_cutoff = fc / half_length
    numerator, denominator = signal.butter(5, normalized_cutoff, 'low')
    smoothed = signal.filtfilt(numerator, denominator, r)
    return np.clip(smoothed, a_min=0.0, a_max=1.0)


def _hsv_encode(image: Image.Image, feature_bins: int = 180, mf: Optional[int] = 5,
                maxpixels: int = 20000, fc: Optional[int] = 75, normalize: bool = True):
    """
    Encode an image as per-channel HSV histograms.

    Steps, in order: downscale so the pixel count is roughly at most
    ``maxpixels``, optionally median-filter, convert to HSV, build one
    normalized histogram per channel, optionally low-pass filter each
    histogram, and optionally standardize each channel's histogram.

    :param image: PIL image to encode.
    :param feature_bins: Number of histogram bins per channel.
    :param mf: Size passed to ``ImageFilter.MedianFilter``; ``None`` skips
        the median filter.
    :param maxpixels: Images with more pixels than this are resized down,
        keeping the aspect ratio.
    :param fc: Cutoff passed to ``butterworth_filter``; ``None`` skips the
        low-pass filtering of the histograms.
    :param normalize: When true, subtract each channel's mean and divide by
        its sample standard deviation (``ddof=1``).
    :return: Array of shape ``(3, feature_bins)``.
    """
    if image.width * image.height > maxpixels:
        r = (image.width * image.height / maxpixels) ** 0.5
        # Clamp each side to at least one pixel: for very elongated images
        # the scaled short side rounds to 0, which ``resize`` cannot handle.
        new_width, new_height = (max(1, int(round(side / r))) for side in image.size)
        image = image.resize((new_width, new_height))

    if mf is not None:
        image = image.filter(ImageFilter.MedianFilter(mf))
    image = image.convert('HSV')

    # (H, W, 3) uint8 -> (3, H, W) floats scaled into [0, 1].
    data = (np.transpose(np.asarray(image), (2, 0, 1)) / 255.0).astype(np.float32)
    channels = [np_hist(data[i], bins=feature_bins) for i in range(3)]
    if fc is not None:
        channels = [butterworth_filter(ch, fc) for ch in channels]

    dist = np.stack(channels)
    assert dist.shape == (3, feature_bins)

    if normalize:
        mean = np.mean(dist, axis=1, keepdims=True)
        std = np.std(dist, axis=1, keepdims=True, ddof=1)
        # A perfectly flat channel has zero deviation; dividing by it would
        # yield NaN, so such a channel is left as all zeros instead.
        dist = (dist - mean) / np.where(std == 0, 1.0, std)

    return dist


def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float:
    """
    Run the monochrome model on an image and return its score.

    The image is loaded in RGB mode, encoded with ``_hsv_encode``, wrapped
    into a batch of one and fed to the ONNX model as ``input``.

    :param image: Image to score, in any form accepted by ``load_image``.
    :param ckpt: Checkpoint filename of the model to use.
    :return: Element ``[0][1]`` of the model's ``output`` tensor as a float
        (presumably the score of the "monochrome" class; confirm against
        the model definition).
    """
    rgb_image = load_image(image, mode='RGB')
    features = _hsv_encode(rgb_image).astype(np.float32)
    batch = np.expand_dims(features, axis=0)

    session = _monochrome_validate_model(ckpt)
    output_data, = session.run(['output'], {'input': batch})
    return float(output_data[0][1])


def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool:
    """
    Decide whether an image counts as monochrome.

    :param image: Image to check, passed through to ``get_monochrome_score``.
    :param threshold: Minimum score (inclusive) for a positive answer.
    :param ckpt: Checkpoint filename of the model to use.
    :return: ``True`` when the score is greater than or equal to ``threshold``.
    """
    score = get_monochrome_score(image, ckpt)
    return score >= threshold
+2 −1
Original line number Diff line number Diff line
torch
torch<2
lpips
matplotlib
torchvision
@@ -12,3 +12,4 @@ di-toolkit
tensorboard
einops
thop
accelerate
+478 KiB
Loading image diff...
+42 −4
Original line number Diff line number Diff line
import os.path
import tempfile
from functools import partial
from typing import List, Tuple
from typing import List, Tuple, Optional

import click
import torch
from hbutils.testing import disable_output
from huggingface_hub import hf_hub_download
from tqdm.auto import tqdm

@@ -23,11 +25,45 @@ def cli():
    pass  # pragma: no cover


@cli.command('onnx_check', help='Check onnx export is okay or not')
@click.option('--model', '-m', 'model', type=click.Choice(list(_KNOWN_MODELS.keys())), default=None,
              help='Model to be checked. All models will be checked when not given.', show_default=True)
@click.option('--feature_bins', '-b', 'feature_bins', type=int, default=180,
              help='Feature bins of input.', show_default=True)
@click.option('--verbose', '-V', 'verbose', is_flag=True, type=bool, default=False,
              help='Show verbose information.', show_default=True)
@click.option('--output_dir', '-O', 'output_dir', type=click.Path(file_okay=False), default=None,
              help='Output directory of all models.', show_default=True)
def onnx_check(model: Optional[str] = None, feature_bins: int = 180, verbose: bool = False,
               output_dir: Optional[str] = None):
    """
    Try exporting one or all known models to ONNX and report the outcome.

    Each model in ``_KNOWN_MODELS`` (or only ``model`` when given) is
    instantiated, cast to float and exported to ``<model>.onnx``. ``OK`` or
    ``FAILED`` is echoed per model; the first failure is re-raised.

    :param model: Name of the single model to check; all known models are
        checked when empty.
    :param feature_bins: Feature bin count forwarded to the exporter.
    :param verbose: When false, output produced during export is suppressed.
    :param output_dir: Directory receiving the ``.onnx`` files. It is created
        when missing. When not given, files go to a temporary directory that
        is removed afterwards.
    """
    models = [model] if model else list(_KNOWN_MODELS.keys())

    if output_dir:
        # The option does not require the directory to exist, so create it
        # up front instead of failing on the first export.
        os.makedirs(output_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as td:
        for _model in models:
            click.echo(click.style(f'Try exporting {_model} to onnx ...'), nl=False)
            _torch_model = _KNOWN_MODELS[_model]().float()
            onnx_filename = os.path.join(output_dir or td, f'{_model}.onnx')
            try:
                if verbose:
                    export_model_to_onnx(_torch_model, onnx_filename, verbose=verbose, feature_bins=feature_bins)
                else:
                    with disable_output():
                        export_model_to_onnx(_torch_model, onnx_filename, verbose=verbose, feature_bins=feature_bins)
            except BaseException:
                # Deliberately broad: the status line must be completed for
                # any kind of abort, and the error is re-raised untouched.
                click.echo(click.style('FAILED', fg='red'), nl=True)
                raise
            else:
                click.echo(click.style('OK', fg='green'), nl=True)


@cli.command('export_one', help='Export one model as onnx.',
             context_settings={**GLOBAL_CONTEXT_SETTINGS})
@click.option('--output', '-o', 'output', type=click.Path(dir_okay=False), required=True,
              help='Output path of feature model.', show_default=True)
@click.option('--feature_bins', '-b', 'feature_bins', type=int, default=256,
@click.option('--feature_bins', '-b', 'feature_bins', type=int, default=180,
              help='Feature bins of input.', show_default=True)
@click.option('--ckpt', '-c', 'ckpt', type=click.Path(exists=True, dir_okay=False), required=True,
              help='Checkpoint file to export.', show_default=True)
@@ -40,8 +76,10 @@ def export_one(output: str, feature_bins: int, ckpt: str, model_name: str):


# Table of known checkpoints. Judging from the values, each entry looks like
# (checkpoint filename, model name, feature bins) -- TODO confirm against the
# code that consumes this list. Commented-out entries are currently disabled.
_KNOWN_CKPTS: List[Tuple[str, str, int]] = [
    ('monochrome-alexnet_plus-320.ckpt', 'alexnet', 256),
    ('monochrome-alexnet_plus-500.ckpt', 'alexnet', 256),
    # ('monochrome-alexnet-480.ckpt', 'alexnet', 180),
    # ('monochrome-resnet18-480.ckpt', 'resnet18', 180),
    ('monochrome-transformer-480.ckpt', 'transformer', 180),
    # ('monochrome-resnet18-safe2-450.ckpt', 'resnet18', 180),
]


+1 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import torch.nn as nn
class MonochromeAlexNet(nn.Module):
    __model_name__ = 'alexnet'

    def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 7):
    def __init__(self, input_channels: int = 3, num_classes=2, avgpool_size: int = 4):
        super(MonochromeAlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(input_channels, 96, kernel_size=11, stride=4, padding=2),
Loading