Commit b3ba741b authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add waifu2x

parent deb5f041
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
from .waifu2x import upscale_image_by_waifu2x
+88 −0
Original line number Diff line number Diff line
import os.path
import re
from functools import lru_cache
from typing import Optional, Mapping, Tuple

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download, HfFileSystem

from ..data import ImageTyping, load_image
from ..utils import open_onnx_model, area_batch_run

_hf_fs = HfFileSystem()

_REPOSITORY = 'deepghs/waifu2x_onnx'
_FILENAME_PATTERN = re.compile(r'^(noise(?P<noise>\d+)_?)?(scale(?P<scale>\d+)x)?\.onnx$')


@lru_cache()
def _load_available_version() -> Mapping[Tuple[str, str, str], Mapping[Tuple[Optional[int], int], str]]:
    records = {}
    for file in _hf_fs.glob(f'{_REPOSITORY}/*/onnx_models/*/*/*.onnx'):
        segments = os.path.relpath(file, _REPOSITORY).split('/')
        assert len(segments) == 5 and segments[1] == 'onnx_models'
        version, model, type_ = segments[0], segments[2], segments[3]
        filename = segments[4]

        key = (version, model, type_)
        if key not in records:
            records[key] = []
        records[key].append(filename)

    retval = {}
    for key, filenames in records.items():
        retval[key] = {}
        for filename in filenames:
            matching = _FILENAME_PATTERN.fullmatch(filename)
            assert matching, f'Not matched, {filename!r}, key: {key!r}'
            noise = int(matching.group('noise')) if matching.group('noise') else None
            scale = int(matching.group('scale')) if matching.group('scale') else 1
            retval[key][(noise, scale)] = filename

    return retval


@lru_cache()
def _open_waifu2x_onnx_model(version: str, model: str, type_: str, noise: Optional[int], scale: int):
    _all_versions = _load_available_version()
    if (version, model, type_) in _all_versions:
        _all_k = _all_versions[(version, model, type_)]
        if (noise, scale) in _all_k:
            filename = _all_k[(noise, scale)]
            return open_onnx_model(hf_hub_download(
                f'deepghs/waifu2x_onnx',
                f'{version}/onnx_models/{model}/{type_}/{filename}',
            ))
        else:
            raise ValueError(f'Noise {noise!r} or scale {scale!r} may be invalid.')
    else:
        raise ValueError(f"Version {version!r} or model {model!r} or type_ {type_!r} may be invalid.")


def _single_upscale_by_waifu2x(x, version: str = '20230504', model: str = 'swin_unet',
                               type_: str = 'art', noise: Optional[int] = None, scale: int = 2):
    ort = _open_waifu2x_onnx_model(version, model, type_, noise, scale)
    # noinspection PyTypeChecker
    x = np.pad(x, ((0, 0), (0, 0), (8, 8), (8, 8)), mode='reflect')
    y, = ort.run(['y'], {'x': x})
    return y


def upscale_image_by_waifu2x(image: ImageTyping, scale: int = 2, noise: Optional[int] = None,
                             version: str = '20230504', model: str = 'swin_unet', type_: str = 'art',
                             tile_size: int = 64, tile_overlap: int = 8, silent: bool = False) -> Image.Image:
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
    input_ = input_.transpose((2, 0, 1))[None, ...]

    def _method(ix):
        return _single_upscale_by_waifu2x(ix, version, model, type_, noise, scale)

    output_ = area_batch_run(
        input_, _method,
        scale=scale, tile_size=tile_size, tile_overlap=tile_overlap, silent=silent,
        process_title='Waifu2x Upscale',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    return Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
+2 −0
Original line number Diff line number Diff line
@@ -2,4 +2,6 @@
Overview:
    Generic utilities for :mod:`imgutils`.
"""
from .area import *
from .onnxruntime import *
from .tqdm_ import *

imgutils/utils/area.py

0 → 100644
+82 −0
Original line number Diff line number Diff line
import math

import numpy as np

from .tqdm_ import tqdm

__all__ = ['area_batch_run']


def area_batch_run(origin_input: np.ndarray, func, scale: int = 1,
                   tile_size: int = 512, tile_overlap: int = 16, batch_size: int = 4,
                   input_channels: int = 3, output_channels: int = 3, silent: bool = False,
                   process_title: str = 'Process Tiles', rebuild_title: str = 'Rebuild Tiles'):
    """
    Perform a batch execution of a given function on overlapping tiles of a large image.

    This function divides the original input image into tiles, applies a given function to each tile,
    and then reconstructs the image by combining the processed tiles.

    :param origin_input: The original input image as a NumPy ndarray with shape (batch, channels, height, width).
    :type origin_input: np.ndarray
    :param func: The function to apply to each tile. It should accept a tile (np.ndarray)
        as input and return the processed tile.
    :type func: callable
    :param scale: Scaling factor for output, defaults to 1.
    :type scale: int, optional
    :param tile_size: Size of the tiles, defaults to 512.
    :type tile_size: int, optional
    :param tile_overlap: Overlap between adjacent tiles, defaults to 16.
    :type tile_overlap: int, optional
    :param batch_size: Batch size for processing tiles, defaults to 4.
    :type batch_size: int, optional
    :param input_channels: Number of input channels, defaults to 3.
    :type input_channels: int, optional
    :param output_channels: Number of output channels, defaults to 3.
    :type output_channels: int, optional
    :param silent: If True, suppresses the progress bar output, defaults to False.
    :type silent: bool, optional
    :param process_title: Title for the processing progress bar, defaults to 'Process Tiles'.
    :type process_title: str, optional
    :param rebuild_title: Title for the rebuilding progress bar, defaults to 'Rebuild Tiles'.
    :type rebuild_title: str, optional
    :return: Processed image as a NumPy ndarray.
    :rtype: np.ndarray
    """
    batch, channels, height, width = origin_input.shape
    assert channels == input_channels, f'Input channels {input_channels!r} expected, but {channels!r} found.'

    tile = min(tile_size, height, width)
    stride = tile - tile_overlap
    h_idx_list = list(range(0, height - tile, stride)) + [height - tile]
    w_idx_list = list(range(0, width - tile, stride)) + [width - tile]
    sum_ = np.zeros((batch, output_channels, height * scale, width * scale), dtype=origin_input.dtype)
    weight = np.zeros_like(sum_, dtype=origin_input.dtype)

    all_patch = []
    all_idx = []

    with tqdm(total=math.ceil(len(h_idx_list) * len(w_idx_list) / batch_size),
              desc=process_title, silent=silent) as pbar:
        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = origin_input[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                all_patch.append(in_patch)
                all_idx.append((h_idx, w_idx))

        results = []
        for i in range(0, len(all_patch), batch_size):
            input_ = np.concatenate(all_patch[i:i + batch_size])
            output_ = func(input_)
            for idx, (h_idx, w_idx) in enumerate(all_idx[i:i + batch_size]):
                results.append((h_idx, w_idx, output_[idx]))
            pbar.update()

    for h_idx, w_idx, output_ in tqdm(results, desc=rebuild_title, silent=silent):
        out_patch_mask = np.ones_like(output_)
        h_min, h_max = h_idx * scale, (h_idx + tile) * scale
        w_min, w_max = w_idx * scale, (w_idx + tile) * scale
        sum_[..., h_min:h_max, w_min:w_max] += output_
        weight[..., h_min:h_max, w_min:w_max] += out_patch_mask

    return sum_ / weight
+25 −0
Original line number Diff line number Diff line
import io

from tqdm.auto import tqdm as _origin_tqdm

__all__ = ['tqdm']


def tqdm(*args, silent: bool = False, **kwargs):
    """
    An enhanced version of tqdm (progress bar) with an option to silence the output.

    This function modifies the behavior of tqdm to allow silencing the progress bar.

    :param args: Positional arguments to be passed to tqdm.
    :param silent: If True, the progress bar content will not be displayed.
    :type silent: bool
    :param kwargs: Additional keyword arguments to be passed to tqdm.
    :return: tqdm progress bar.
    :rtype: tqdm.std.tqdm
    """
    with io.StringIO() as sio:
        if silent:
            kwargs['file'] = sio

        return _origin_tqdm(*args, **kwargs)