Commit 3fa1ef97 authored by dmMaze's avatar dmMaze
Browse files

add new inpainting model lama_large_512px and set as default, improved manga...

add new inpainting model lama_large_512px and set as default, improved manga inpainting #256, support bf16, fp16 inference for lower memory consumption #262
parent 45f57907
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -118,3 +118,8 @@ DEVICE_SELECTOR = lambda : deepcopy(
    }
)

TORCH_DTYPE_MAP = {
    'fp32': torch.float32,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}
+82 −10
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from utils.registry import Registry
from utils.textblock_mask import extract_ballon_mask
from utils.imgproc_utils import enlarge_window

from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR, GPUINTENSIVE_SET
from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP
from ..textdetector import TextBlock

INPAINTERS = Registry('inpainters')
@@ -52,7 +52,11 @@ class InpainterBase(BaseModule):
                                            if running into it frequently, consider lowering the inpaint_size')
                        self.moveToDevice('cpu')
                        inpainted = self._inpaint(img, mask, textblock_list)
                        self.moveToDevice('cuda')
                        precision = None
                        if hasattr(self, 'precision'):
                            precision = self.precision
                        self.moveToDevice('cuda', precision)

                        return inpainted
            else:
                raise e
@@ -108,7 +112,7 @@ class InpainterBase(BaseModule):
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        raise NotImplementedError
    
    def moveToDevice(self, device: str):
    def moveToDevice(self, device: str, precision: str = None):
        raise not NotImplementedError


@@ -207,7 +211,7 @@ class AOTInpainter(InpainterBase):
            self.model.to(self.device)
        self.inpaint_size = int(self.params['inpaint_size']['select'])

    def moveToDevice(self, device: str):
    def moveToDevice(self, device: str, precision: str = None):
        self.model.to(device)
        self.device = device

@@ -288,6 +292,7 @@ class LamaInpainterMPE(InpainterBase):
        }, 
        'device': DEVICE_SELECTOR()
    }
    precision = 'fp32'

    device = DEFAULT_DEVICE
    inpaint_size = 2048
@@ -298,10 +303,6 @@ class LamaInpainterMPE(InpainterBase):
            'files': 'data/models/lama_mpe.ckpt',
    }]

    def moveToDevice(self, device: str):
        self.model.to(device)
        self.device = device

    def setup_inpainter(self):
        global LAMA_MPE

@@ -354,9 +355,20 @@ class LamaInpainterMPE(InpainterBase):

        im_h, im_w = img.shape[:2]
        img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask)
        
        precision = TORCH_DTYPE_MAP[self.precision]
        if self.device in {'cuda', 'mps'}:
            try:
                with torch.autocast(device_type=self.device, dtype=precision):
                    img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct)
            except Exception as e:
                self.logger.error(e)
                self.logger.error(f'{precision} inference is not supported for this device, use fp32 instead.')
                img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct)
        else:
            img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct)

        img_inpainted = (img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        img_inpainted = (img_inpainted_torch.to(device='cpu', dtype=torch.float32).squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        if pad_bottom > 0:
            img_inpainted = img_inpainted[:-pad_bottom]
        if pad_right > 0:
@@ -379,6 +391,66 @@ class LamaInpainterMPE(InpainterBase):
        elif param_key == 'inpaint_size':
            self.inpaint_size = int(self.params['inpaint_size']['select'])

        elif param_key == 'precision':
            precision = self.params['precision']['select']
            self.precision = precision

    def moveToDevice(self, device: str, precision: str = None):
        self.model.to(device)
        self.device = device
        if precision is not None:
            self.precision = precision

LAMA_LARGE: LamaFourier = None
@register_inpainter('lama_large_512px')
class LamaLarge(LamaInpainterMPE):

    params = {
        'inpaint_size': {
            'type': 'selector',
            'options': [
                512,
                768,
                1024,
                1536, 
                2048
            ], 
            'select': 1536,
        }, 
        'device': DEVICE_SELECTOR(),
        'precision': {
            'type': 'selector',
            'options': [
                'fp16', 
                'fp32',
                'bf16'
            ], 
            'select': 'bf16'
        }, 
    }

    download_file_list = [{
            'url': 'https://huggingface.co/dreMaz/AnimeMangaInpainting/resolve/main/lama_large_512px.ckpt',
            'sha256_pre_calculated': '11d30fbb3000fb2eceae318b75d9ced9229d99ae990a7f8b3ac35c8d31f2c935',
            'files': 'data/models/lama_large_512px.ckpt',
    }]

    device = DEFAULT_DEVICE
    inpaint_size = 1024

    def setup_inpainter(self):
        global LAMA_LARGE

        device = self.params['device']['select']
        self.inpaint_size = int(self.params['inpaint_size']['select'])
        precision = self.params['precision']['select']

        if LAMA_LARGE is None:
            self.model = LAMA_LARGE = load_lama_mpe(r'data/models/lama_large_512px.ckpt', device='cpu', use_mpe=False, large_arch=True)
        else:
            self.model = LAMA_LARGE
        self.moveToDevice(device, precision=precision)


# LAMA_ORI: LamaFourier = None
# @register_inpainter('lama_ori')
+55 −6
Original line number Diff line number Diff line
@@ -69,6 +69,57 @@ class FourierUnit(nn.Module):
        self.ffc3d = ffc3d
        self.fft_norm = fft_norm

    # def forward(self, x):
    #     batch = x.shape[0]
    #     input_dtype = x.dtype

    #     if self.spatial_scale_factor is not None:
    #         orig_size = x.shape[-2:]
    #         x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)

    #     # (batch, c, h, w/2+1, 2)
    #     fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
    #     # x: torch.float16

    #     if input_dtype != torch.float32:
    #         x = x.type(torch.float32)
    #     ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
    #     ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
    #     ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
    #     ffted = ffted.view((batch, -1,) + ffted.size()[3:])

    #     if self.spectral_pos_encoding:
    #         height, width = ffted.shape[-2:]
    #         coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
    #         coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
    #         ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)

    #     if self.use_se:
    #         ffted = self.se(ffted)

    #     if ffted.dtype != input_dtype:
    #         ffted = ffted.type(input_dtype)
    #     ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
    #     ffted = self.relu(self.bn(ffted))

    #     ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
    #         0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        
    #     if input_dtype != torch.float32:
    #         ffted = ffted.type(torch.float32)
    #     ffted = torch.complex(ffted[..., 0], ffted[..., 1])

    #     ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
    #     output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)

    #     if output.dtype != input_dtype:
    #         output = output.type(input_dtype)

    #     if self.spatial_scale_factor is not None:
    #         output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)

    #     return output

    def forward(self, x):
        batch = x.shape[0]

@@ -79,12 +130,10 @@ class FourierUnit(nn.Module):
        r_size = x.size()
        # (batch, c, h, w/2+1, 2)
        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
        # x: torch.float16
        if x.dtype == torch.float16:
            half = True

        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.type(torch.float32)
        else:
            half = False

        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
@@ -104,7 +153,7 @@ class FourierUnit(nn.Module):

        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
            0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
        if ffted.dtype == torch.float16:
        if ffted.dtype in (torch.float16, torch.bfloat16):
            ffted = ffted.type(torch.float32)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

+11 −4
Original line number Diff line number Diff line
@@ -251,9 +251,15 @@ class MPE(nn.Module):


class LamaFourier:
    def __init__(self, build_discriminator=True, use_mpe=False) -> None:
    def __init__(self, build_discriminator=True, use_mpe=False, large_arch: bool = False) -> None:
        # super().__init__()

        n_blocks = 9
        if large_arch:
            n_blocks = 18
        
        self.generator = FFCResNetGenerator(4, 3, add_out_act='sigmoid', 
                            n_blocks = n_blocks,
                            init_conv_kwargs={
                            'ratio_gin': 0,
                            'ratio_gout': 0,
@@ -266,8 +272,9 @@ class LamaFourier:
                            'ratio_gin': 0.75,
                            'ratio_gout': 0.75,
                            'enable_lfu': False
                        }
                        }, 
                    )
        
        self.discriminator = NLayerDiscriminator() if build_discriminator else None
        self.inpaint_only = False
        if use_mpe:
@@ -413,8 +420,8 @@ class LamaFourier:

        return rel_pos, abs_pos, direct

def load_lama_mpe(model_path, device, use_mpe=True) -> LamaFourier:
    model = LamaFourier(build_discriminator=False, use_mpe=use_mpe)
def load_lama_mpe(model_path, device, use_mpe=True, large_arch: bool = False) -> LamaFourier:
    model = LamaFourier(build_discriminator=False, use_mpe=use_mpe, large_arch=large_arch)
    sd = torch.load(model_path, map_location = 'cpu')
    model.generator.load_state_dict(sd['gen_state_dict'])
    if use_mpe:
+1 −1
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from .structures import Tuple, Union, List, Dict, Config, field, nested_dataclas
class ModuleConfig(Config):
    textdetector: str = 'ctd'
    ocr: str = "mit48px_ctc"
    inpainter: str = 'lama_mpe'
    inpainter: str = 'lama_large_512px'
    translator: str = "google"
    enable_detect: bool = True
    enable_ocr: bool = True