Commit d0700f78 authored by dmMaze's avatar dmMaze
Browse files

try to fix color deviation for inpainting #733

parent 1316ceca
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -68,7 +68,7 @@ class InpainterBase(BaseModule):
                if ballon_msk is not None:
                    non_text_region = np.where(non_text_msk > 0)
                    non_text_px = img[non_text_region]
                    average_bg_color = np.mean(non_text_px, axis=0)
                    average_bg_color = np.median(non_text_px, axis=0)
                    std_bgr = np.std(non_text_px - average_bg_color, axis=0)
                    std_max = np.max(std_bgr)
                    inpaint_thresh = 7 if np.std(std_bgr) > 1 else 10
@@ -91,7 +91,7 @@ class InpainterBase(BaseModule):
                    if ballon_msk is not None:
                        non_text_region = np.where(non_text_msk > 0)
                        non_text_px = im[non_text_region]
                        average_bg_color = np.mean(non_text_px, axis=0)
                        average_bg_color = np.median(non_text_px, axis=0)
                        std_bgr = np.std(non_text_px - average_bg_color, axis=0)
                        std_max = np.max(std_bgr)
                        inpaint_thresh = 7 if np.std(std_bgr) > 1 else 10
@@ -253,7 +253,8 @@ class AOTInpainter(InpainterBase):
        im_h, im_w = img.shape[:2]
        img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask)
        img_inpainted_torch = self.model(img_torch, mask_torch)
        img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5).astype(np.uint8)
        img_inpainted = ((img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() + 1.0) * 127.5)
        img_inpainted = (np.clip(np.round(img_inpainted), 0, 255)).astype(np.uint8)
        if pad_bottom > 0:
            img_inpainted = img_inpainted[:-pad_bottom]
        if pad_right > 0:
@@ -366,7 +367,8 @@ class LamaInpainterMPE(InpainterBase):
        else:
            img_inpainted_torch = self.model(img_torch, mask_torch, rel_pos, direct)

        img_inpainted = (img_inpainted_torch.to(device='cpu', dtype=torch.float32).squeeze_(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        img_inpainted = (img_inpainted_torch.to(device='cpu', dtype=torch.float32).squeeze_(0).permute(1, 2, 0).numpy() * 255)
        img_inpainted = (np.clip(np.round(img_inpainted), 0, 255)).astype(np.uint8)
        if pad_bottom > 0:
            img_inpainted = img_inpainted[:-pad_bottom]
        if pad_right > 0: