Loading modules/base.py +5 −0 Original line number Diff line number Diff line Loading @@ -118,3 +118,8 @@ DEVICE_SELECTOR = lambda : deepcopy( } ) TORCH_DTYPE_MAP = { 'fp32': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16, } modules/inpaint/base.py +82 −10 Original line number Diff line number Diff line Loading @@ -8,7 +8,7 @@ from utils.registry import Registry from utils.textblock_mask import extract_ballon_mask from utils.imgproc_utils import enlarge_window from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR, GPUINTENSIVE_SET from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP from ..textdetector import TextBlock INPAINTERS = Registry('inpainters') Loading Loading @@ -52,7 +52,11 @@ class InpainterBase(BaseModule): if running into it frequently, consider lowering the inpaint_size') self.moveToDevice('cpu') inpainted = self._inpaint(img, mask, textblock_list) self.moveToDevice('cuda') precision = None if hasattr(self, 'precision'): precision = self.precision self.moveToDevice('cuda', precision) return inpainted else: raise e Loading Loading @@ -108,7 +112,7 @@ class InpainterBase(BaseModule): def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray: raise NotImplementedError def moveToDevice(self, device: str): def moveToDevice(self, device: str, precision: str = None): raise not NotImplementedError Loading Loading @@ -207,7 +211,7 @@ class AOTInpainter(InpainterBase): self.model.to(self.device) self.inpaint_size = int(self.params['inpaint_size']['select']) def moveToDevice(self, device: str): def moveToDevice(self, device: str, precision: str = None): self.model.to(device) self.device = device Loading Loading @@ -288,6 +292,7 @@ class LamaInpainterMPE(InpainterBase): }, 'device': DEVICE_SELECTOR() } precision = 'fp32' device = DEFAULT_DEVICE inpaint_size = 2048 Loading @@ -298,10 +303,6 @@ class LamaInpainterMPE(InpainterBase): 'files': 'data/models/lama_mpe.ckpt', }] def moveToDevice(self, device: str): self.model.to(device) self.device = device def setup_inpainter(self): global LAMA_MPE Loading Loading @@ -354,9 +355,20 @@ class LamaInpainterMPE(InpainterBase): im_h, im_w = img.shape[:2] img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask) precision = TORCH_DTYPE_MAP[self.precision] if self.device in {'cuda', 'mps'}: try: with torch.autocast(device_type=self.device, dtype=precision): img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct) except Exception as e: self.logger.error(e) self.logger.error(f'{precision} inference is not supported for this device, use fp32 instead.') img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct) else: img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct) img_inpainted = (img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8) img_inpainted = (img_inpainted_torch.to(device='cpu', dtype=torch.float32).squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8) if pad_bottom > 0: img_inpainted = img_inpainted[:-pad_bottom] if pad_right > 0: Loading @@ -379,6 +391,66 @@ class LamaInpainterMPE(InpainterBase): elif param_key == 'inpaint_size': self.inpaint_size = int(self.params['inpaint_size']['select']) elif param_key == 'precision': precision = self.params['precision']['select'] self.precision = precision def moveToDevice(self, device: str, precision: str = None): self.model.to(device) self.device = device if precision is not None: self.precision = precision LAMA_LARGE: LamaFourier = None @register_inpainter('lama_large_512px') class LamaLarge(LamaInpainterMPE): params = { 'inpaint_size': { 'type': 'selector', 'options': [ 512, 768, 1024, 1536, 2048 ], 'select': 1536, }, 'device': DEVICE_SELECTOR(), 'precision': { 'type': 'selector', 'options': [ 'fp16', 'fp32', 'bf16' ], 'select': 'bf16' }, } download_file_list = [{ 'url': 'https://huggingface.co/dreMaz/AnimeMangaInpainting/resolve/main/lama_large_512px.ckpt', 'sha256_pre_calculated': '11d30fbb3000fb2eceae318b75d9ced9229d99ae990a7f8b3ac35c8d31f2c935', 'files': 'data/models/lama_large_512px.ckpt', }] device = DEFAULT_DEVICE inpaint_size = 1024 def setup_inpainter(self): global LAMA_LARGE device = self.params['device']['select'] self.inpaint_size = int(self.params['inpaint_size']['select']) precision = self.params['precision']['select'] if LAMA_LARGE is None: self.model = LAMA_LARGE = load_lama_mpe(r'data/models/lama_large_512px.ckpt', device='cpu', use_mpe=False, large_arch=True) else: self.model = LAMA_LARGE self.moveToDevice(device, precision=precision) # LAMA_ORI: LamaFourier = None # @register_inpainter('lama_ori') Loading modules/inpaint/ffc.py +55 −6 Original line number Diff line number Diff line Loading @@ -69,6 +69,57 @@ class FourierUnit(nn.Module): self.ffc3d = ffc3d self.fft_norm = fft_norm # def forward(self, x): # batch = x.shape[0] # input_dtype = x.dtype # if self.spatial_scale_factor is not None: # orig_size = x.shape[-2:] # x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False) # # (batch, c, h, w/2+1, 2) # fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) # # x: torch.float16 # if input_dtype != torch.float32: # x = x.type(torch.float32) # ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) # ffted = torch.stack((ffted.real, ffted.imag), dim=-1) # ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) # ffted = ffted.view((batch, -1,) + ffted.size()[3:]) # if self.spectral_pos_encoding: # height, width = ffted.shape[-2:] # coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) # coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) # ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) # if self.use_se: # ffted = self.se(ffted) # if ffted.dtype != input_dtype: # ffted = ffted.type(input_dtype) # ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) # ffted = self.relu(self.bn(ffted)) # ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( # 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) # if input_dtype != torch.float32: # ffted = ffted.type(torch.float32) # ffted = torch.complex(ffted[..., 0], ffted[..., 1]) # ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] # output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) # if output.dtype != input_dtype: # output = output.type(input_dtype) # if self.spatial_scale_factor is not None: # output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) # return output def forward(self, x): batch = x.shape[0] Loading @@ -79,12 +130,10 @@ class FourierUnit(nn.Module): r_size = x.size() # (batch, c, h, w/2+1, 2) fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) # x: torch.float16 if x.dtype == torch.float16: half = True if x.dtype in (torch.float16, torch.bfloat16): x = x.type(torch.float32) else: half = False ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) ffted = torch.stack((ffted.real, ffted.imag), dim=-1) ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) Loading @@ -104,7 +153,7 @@ class FourierUnit(nn.Module): ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) if ffted.dtype == torch.float16: if ffted.dtype in (torch.float16, torch.bfloat16): ffted = ffted.type(torch.float32) ffted = torch.complex(ffted[..., 0], ffted[..., 1]) Loading modules/inpaint/lama.py +11 −4 Original line number Diff line number Diff line Loading @@ -251,9 +251,15 @@ class MPE(nn.Module): class LamaFourier: def __init__(self, build_discriminator=True, use_mpe=False) -> None: def __init__(self, build_discriminator=True, use_mpe=False, large_arch: bool = False) -> None: # super().__init__() n_blocks = 9 if large_arch: n_blocks = 18 self.generator = FFCResNetGenerator(4, 3, add_out_act='sigmoid', n_blocks = n_blocks, init_conv_kwargs={ 'ratio_gin': 0, 'ratio_gout': 0, Loading @@ -266,8 +272,9 @@ class LamaFourier: 'ratio_gin': 0.75, 'ratio_gout': 0.75, 'enable_lfu': False } }, ) self.discriminator = NLayerDiscriminator() if build_discriminator else None self.inpaint_only = False if use_mpe: Loading Loading @@ -413,8 +420,8 @@ class LamaFourier: return rel_pos, abs_pos, direct def load_lama_mpe(model_path, device, use_mpe=True) -> LamaFourier: model = LamaFourier(build_discriminator=False, use_mpe=use_mpe) def load_lama_mpe(model_path, device, use_mpe=True, large_arch: bool = False) -> LamaFourier: model = LamaFourier(build_discriminator=False, use_mpe=use_mpe, large_arch=large_arch) sd = torch.load(model_path, map_location = 'cpu') model.generator.load_state_dict(sd['gen_state_dict']) if use_mpe: Loading utils/config.py +1 −1 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ from .structures import Tuple, Union, List, Dict, Config, field, nested_dataclas class ModuleConfig(Config): textdetector: str = 'ctd' ocr: str = "mit48px_ctc" inpainter: str = 'lama_mpe' inpainter: str = 'lama_large_512px' translator: str = "google" enable_detect: bool = True enable_ocr: bool = True Loading Loading
modules/base.py +5 −0 Original line number Diff line number Diff line Loading @@ -118,3 +118,8 @@ DEVICE_SELECTOR = lambda : deepcopy( } ) TORCH_DTYPE_MAP = { 'fp32': torch.float32, 'fp16': torch.float16, 'bf16': torch.bfloat16, }
modules/inpaint/base.py +82 −10 Original line number Diff line number Diff line Loading @@ -8,7 +8,7 @@ from utils.registry import Registry from utils.textblock_mask import extract_ballon_mask from utils.imgproc_utils import enlarge_window from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR, GPUINTENSIVE_SET from ..base import BaseModule, DEFAULT_DEVICE, gc_collect, DEVICE_SELECTOR, GPUINTENSIVE_SET, TORCH_DTYPE_MAP from ..textdetector import TextBlock INPAINTERS = Registry('inpainters') Loading Loading @@ -52,7 +52,11 @@ class InpainterBase(BaseModule): if running into it frequently, consider lowering the inpaint_size') self.moveToDevice('cpu') inpainted = self._inpaint(img, mask, textblock_list) self.moveToDevice('cuda') precision = None if hasattr(self, 'precision'): precision = self.precision self.moveToDevice('cuda', precision) return inpainted else: raise e Loading Loading @@ -108,7 +112,7 @@ class InpainterBase(BaseModule): def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray: raise NotImplementedError def moveToDevice(self, device: str): def moveToDevice(self, device: str, precision: str = None): raise not NotImplementedError Loading Loading @@ -207,7 +211,7 @@ class AOTInpainter(InpainterBase): self.model.to(self.device) self.inpaint_size = int(self.params['inpaint_size']['select']) def moveToDevice(self, device: str): def moveToDevice(self, device: str, precision: str = None): self.model.to(device) self.device = device Loading Loading @@ -288,6 +292,7 @@ class LamaInpainterMPE(InpainterBase): }, 'device': DEVICE_SELECTOR() } precision = 'fp32' device = DEFAULT_DEVICE inpaint_size = 2048 Loading @@ -298,10 +303,6 @@ class LamaInpainterMPE(InpainterBase): 'files': 'data/models/lama_mpe.ckpt', }] def moveToDevice(self, device: str): self.model.to(device) self.device = device def setup_inpainter(self): global LAMA_MPE Loading Loading @@ -354,9 +355,20 @@ class LamaInpainterMPE(InpainterBase): im_h, im_w = img.shape[:2] img_torch, mask_torch, rel_pos, direct, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask) precision = TORCH_DTYPE_MAP[self.precision] if self.device in {'cuda', 'mps'}: try: with torch.autocast(device_type=self.device, dtype=precision): img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct) except Exception as e: self.logger.error(e) self.logger.error(f'{precision} inference is not supported for this device, use fp32 instead.') img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct) else: img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct) img_inpainted = (img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8) img_inpainted = (img_inpainted_torch.to(device='cpu', dtype=torch.float32).squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8) if pad_bottom > 0: img_inpainted = img_inpainted[:-pad_bottom] if pad_right > 0: Loading @@ -379,6 +391,66 @@ class LamaInpainterMPE(InpainterBase): elif param_key == 'inpaint_size': self.inpaint_size = int(self.params['inpaint_size']['select']) elif param_key == 'precision': precision = self.params['precision']['select'] self.precision = precision def moveToDevice(self, device: str, precision: str = None): self.model.to(device) self.device = device if precision is not None: self.precision = precision LAMA_LARGE: LamaFourier = None @register_inpainter('lama_large_512px') class LamaLarge(LamaInpainterMPE): params = { 'inpaint_size': { 'type': 'selector', 'options': [ 512, 768, 1024, 1536, 2048 ], 'select': 1536, }, 'device': DEVICE_SELECTOR(), 'precision': { 'type': 'selector', 'options': [ 'fp16', 'fp32', 'bf16' ], 'select': 'bf16' }, } download_file_list = [{ 'url': 'https://huggingface.co/dreMaz/AnimeMangaInpainting/resolve/main/lama_large_512px.ckpt', 'sha256_pre_calculated': '11d30fbb3000fb2eceae318b75d9ced9229d99ae990a7f8b3ac35c8d31f2c935', 'files': 'data/models/lama_large_512px.ckpt', }] device = DEFAULT_DEVICE inpaint_size = 1024 def setup_inpainter(self): global LAMA_LARGE device = self.params['device']['select'] self.inpaint_size = int(self.params['inpaint_size']['select']) precision = self.params['precision']['select'] if LAMA_LARGE is None: self.model = LAMA_LARGE = load_lama_mpe(r'data/models/lama_large_512px.ckpt', device='cpu', use_mpe=False, large_arch=True) else: self.model = LAMA_LARGE self.moveToDevice(device, precision=precision) # LAMA_ORI: LamaFourier = None # @register_inpainter('lama_ori') Loading
modules/inpaint/ffc.py +55 −6 Original line number Diff line number Diff line Loading @@ -69,6 +69,57 @@ class FourierUnit(nn.Module): self.ffc3d = ffc3d self.fft_norm = fft_norm # def forward(self, x): # batch = x.shape[0] # input_dtype = x.dtype # if self.spatial_scale_factor is not None: # orig_size = x.shape[-2:] # x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False) # # (batch, c, h, w/2+1, 2) # fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) # # x: torch.float16 # if input_dtype != torch.float32: # x = x.type(torch.float32) # ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) # ffted = torch.stack((ffted.real, ffted.imag), dim=-1) # ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) # ffted = ffted.view((batch, -1,) + ffted.size()[3:]) # if self.spectral_pos_encoding: # height, width = ffted.shape[-2:] # coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) # coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) # ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) # if self.use_se: # ffted = self.se(ffted) # if ffted.dtype != input_dtype: # ffted = ffted.type(input_dtype) # ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) # ffted = self.relu(self.bn(ffted)) # ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( # 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) # if input_dtype != torch.float32: # ffted = ffted.type(torch.float32) # ffted = torch.complex(ffted[..., 0], ffted[..., 1]) # ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] # output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) # if output.dtype != input_dtype: # output = output.type(input_dtype) # if self.spatial_scale_factor is not None: # output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) # return output def forward(self, x): batch = x.shape[0] Loading @@ -79,12 +130,10 @@ class FourierUnit(nn.Module): r_size = x.size() # (batch, c, h, w/2+1, 2) fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) # x: torch.float16 if x.dtype == torch.float16: half = True if x.dtype in (torch.float16, torch.bfloat16): x = x.type(torch.float32) else: half = False ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) ffted = torch.stack((ffted.real, ffted.imag), dim=-1) ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) Loading @@ -104,7 +153,7 @@ class FourierUnit(nn.Module): ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) if ffted.dtype == torch.float16: if ffted.dtype in (torch.float16, torch.bfloat16): ffted = ffted.type(torch.float32) ffted = torch.complex(ffted[..., 0], ffted[..., 1]) Loading
modules/inpaint/lama.py +11 −4 Original line number Diff line number Diff line Loading @@ -251,9 +251,15 @@ class MPE(nn.Module): class LamaFourier: def __init__(self, build_discriminator=True, use_mpe=False) -> None: def __init__(self, build_discriminator=True, use_mpe=False, large_arch: bool = False) -> None: # super().__init__() n_blocks = 9 if large_arch: n_blocks = 18 self.generator = FFCResNetGenerator(4, 3, add_out_act='sigmoid', n_blocks = n_blocks, init_conv_kwargs={ 'ratio_gin': 0, 'ratio_gout': 0, Loading @@ -266,8 +272,9 @@ class LamaFourier: 'ratio_gin': 0.75, 'ratio_gout': 0.75, 'enable_lfu': False } }, ) self.discriminator = NLayerDiscriminator() if build_discriminator else None self.inpaint_only = False if use_mpe: Loading Loading @@ -413,8 +420,8 @@ class LamaFourier: return rel_pos, abs_pos, direct def load_lama_mpe(model_path, device, use_mpe=True) -> LamaFourier: model = LamaFourier(build_discriminator=False, use_mpe=use_mpe) def load_lama_mpe(model_path, device, use_mpe=True, large_arch: bool = False) -> LamaFourier: model = LamaFourier(build_discriminator=False, use_mpe=use_mpe, large_arch=large_arch) sd = torch.load(model_path, map_location = 'cpu') model.generator.load_state_dict(sd['gen_state_dict']) if use_mpe: Loading
utils/config.py +1 −1 Original line number Diff line number Diff line Loading @@ -9,7 +9,7 @@ from .structures import Tuple, Union, List, Dict, Config, field, nested_dataclas class ModuleConfig(Config): textdetector: str = 'ctd' ocr: str = "mit48px_ctc" inpainter: str = 'lama_mpe' inpainter: str = 'lama_large_512px' translator: str = "google" enable_detect: bool = True enable_ocr: bool = True Loading