Commit 9cb08b00 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add pydocs

parent e4aaf9db
Loading
Loading
Loading
Loading
+34 −1
Original line number Diff line number Diff line
@@ -12,6 +12,15 @@ from ..utils import open_onnx_model, area_batch_run

@lru_cache()
def _open_cdc_upscaler_model(model: str) -> Tuple[Any, int]:
    """
    Opens and initializes the CDC upscaler model.

    :param model: The name of the model to use.
    :type model: str

    :return: Tuple of the ONNX model and the scale factor.
    :rtype: Tuple[Any, int]
    """
    ort = open_onnx_model(hf_hub_download(
        f'deepghs/cdc_anime_onnx',
        f'{model}.onnx'
@@ -34,6 +43,30 @@ _CDC_INPUT_UNIT = 16
def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
                     tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
                     silent: bool = False) -> Image.Image:
    """
    Upscale the input image using the CDC upscaler model.

    :param image: The input image.
    :type image: ImageTyping

    :param model: The name of the model to use. (default: 'HGSR-MHR-anime-aug_X4_320')
    :type model: str

    :param tile_size: The size of each tile. (default: 512)
    :type tile_size: int

    :param tile_overlap: The overlap between tiles. (default: 64)
    :type tile_overlap: int

    :param batch_size: The batch size. (default: 1)
    :type batch_size: int

    :param silent: Whether to suppress progress messages. (default: False)
    :type silent: bool

    :return: The upscaled image.
    :rtype: Image.Image
    """
    image, alpha_mask = _rgba_preprocess(image)
    image = load_image(image, mode='RGB', force_background='white')
    input_ = np.array(image).astype(np.float32) / 255.0
@@ -62,5 +95,5 @@ def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320
        scale=scale, silent=silent, process_title='CDC Upscale',
    )
    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.uint8), 'RGB')
    return _rgba_postprocess(ret_image, alpha_mask)