Commit 6495a5c3 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add upscale

parent a488bf49
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -27,6 +27,9 @@ def _open_cdc_upscaler_model(model: str) -> Tuple[Any, int]:
    return ort, scale_h


_CDC_INPUT_UNIT = 16


def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
                     tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
                     silent: bool = False) -> Image.Image:
@@ -37,9 +40,19 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
    ort, scale = _open_cdc_upscaler_model(model)

    def _method(ix):
        ox, = ort.run(['output'], {'input': ix.astype(np.float32)})
        ix = ix.astype(np.float32)
        batch, channels, height, width = ix.shape
        p_height = 0 if height % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (height % _CDC_INPUT_UNIT)
        p_width = 0 if width % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (width % _CDC_INPUT_UNIT)
        if p_height > 0 or p_width > 0:  # align to 16
            ix = np.pad(ix, ((0, 0), (0, 0), (0, p_height), (0, p_width)), mode='reflect')
        actual_height, actual_width = height, width

        ox, = ort.run(['output'], {'input': ix})
        batch, channels, scale_, height, scale_, width = ox.shape
        return ox.reshape((batch, channels, scale_ * height, scale_ * width))
        ox = ox.reshape((batch, channels, scale_ * height, scale_ * width))
        ox = ox[..., :scale_ * actual_height, :scale_ * actual_width]  # crop back
        return ox

    output_ = area_batch_run(
        input_, _method,
+2 −2
Original line number Diff line number Diff line
@@ -48,8 +48,8 @@ def area_batch_run(origin_input: np.ndarray, func, scale: int = 1,

    tile = min(tile_size, height, width)
    stride = tile - tile_overlap
    h_idx_list = list(range(0, height - tile, stride)) + [height - tile]
    w_idx_list = list(range(0, width - tile, stride)) + [width - tile]
    h_idx_list = sorted(set(list(range(0, height - tile, stride)) + [height - tile]))
    w_idx_list = sorted(set(list(range(0, width - tile, stride)) + [width - tile]))
    sum_ = np.zeros((batch, output_channels, height * scale, width * scale), dtype=origin_input.dtype)
    weight = np.zeros_like(sum_, dtype=origin_input.dtype)

+867 KiB
Loading image diff...
+2.42 MiB
Loading image diff...
+80.3 KiB
Loading image diff...
Loading