Unverified Commit a1115722 authored by 6ri6ver's avatar 6ri6ver Committed by GitHub
Browse files

png transparency support (#962)



* add png transparency support

text detection/ocr/inpaint transparency fixes 0.1

method 2 transparency fix

shapes + text block inpaint transparency fix

full page inpainting transparency fix

* fix and clean up code generated by Claude

---------

Co-authored-by: default avatarno <yeah@gmail.com>
Co-authored-by: default avatardmMaze <beneathlimbo@gmail.com>
parent 194da254
Loading
Loading
Loading
Loading
+55 −8
Original line number Diff line number Diff line
@@ -15,6 +15,28 @@ INPAINTERS = Registry('inpainters')
register_inpainter = INPAINTERS.register_module


def inpaint_handle_alpha_channel(original_alpha, mask):
    '''
    Return the alpha channel to attach to an inpainted RGB result.

    For now the alpha channel is passed through unchanged: the result is an
    independent copy of ``original_alpha``. Feeding the alpha into the
    inpainting model is perhaps a better idea, but it would double the cost.

    Args:
        original_alpha: alpha channel of the source image; anything exposing
            a numpy-style ``copy()`` works.
        mask: inpainting mask. Currently unused; kept in the signature so
            callers do not have to change if alpha is ever adjusted inside
            the masked region.

    Returns:
        A copy of ``original_alpha``. The input array is never modified and
        the result does not share memory with it.
    '''
    # NOTE: an experimental heuristic (spreading the median alpha of a dilated
    # ring around the mask) used to sit after this return. It was unreachable,
    # and it indexed with a uint8 array where a boolean mask was intended, so
    # it was removed rather than left as misleading dead code.
    return original_alpha.copy()


class InpainterBase(BaseModule):

    inpaint_by_block = True
@@ -62,24 +84,44 @@ class InpainterBase(BaseModule):
        if not self.all_model_loaded():
            self.load_model()
        
        # Handle RGBA images by preserving alpha channel
        original_alpha = None
        if len(img.shape) == 3 and img.shape[2] == 4:
            original_alpha = img[:, :, 3:4]  # Keep alpha channel
            img_rgb = img[:, :, :3]  # Use only RGB for inpainting
        else:
            img_rgb = img
        
        if not self.inpaint_by_block or textblock_list is None:
            if check_need_inpaint:
                ballon_msk, non_text_msk = extract_ballon_mask(img, mask)
                ballon_msk, non_text_msk = extract_ballon_mask(img_rgb, mask)
                if ballon_msk is not None:
                    non_text_region = np.where(non_text_msk > 0)
                    non_text_px = img[non_text_region]
                    non_text_px = img_rgb[non_text_region]
                    average_bg_color = np.median(non_text_px, axis=0)
                    std_rgb = np.std(non_text_px - average_bg_color, axis=0)
                    std_max = np.max(std_rgb)
                    inpaint_thresh = 7 if np.std(std_rgb) > 1 else 10
                    if std_max < inpaint_thresh:
                        img = img.copy()
                        img[np.where(ballon_msk > 0)] = average_bg_color
                        return img
            return self.memory_safe_inpaint(img, mask, textblock_list)
                        result_rgb = img_rgb.copy()
                        result_rgb[np.where(ballon_msk > 0)] = average_bg_color
                        # Recombine with alpha if original was RGBA
                        if original_alpha is not None:
                            return np.concatenate([result_rgb, original_alpha], axis=2)
                        return result_rgb
            result_rgb = self.memory_safe_inpaint(img_rgb, mask, textblock_list)
            # Recombine with alpha if original was RGBA
            if original_alpha is not None:
                result_alpha = inpaint_handle_alpha_channel(original_alpha, mask)
                return np.concatenate([result_rgb, result_alpha], axis=2)
            return result_rgb
        else:
            im_h, im_w = img.shape[:2]
            inpainted = np.copy(img)
            im_h, im_w = img_rgb.shape[:2]
            inpainted = np.copy(img_rgb)
            
            # Preserve original mask for transparency analysis
            original_mask = mask.copy()
            
            for blk in textblock_list:
                xyxy = blk.xyxy
                xyxy_e = enlarge_window(xyxy, im_w, im_h, ratio=1.7)
@@ -107,6 +149,11 @@ class InpainterBase(BaseModule):
                    inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]] = self.memory_safe_inpaint(im, msk)

                mask[xyxy[1]:xyxy[3], xyxy[0]:xyxy[2]] = 0
            
            # Recombine with alpha if original was RGBA
            if original_alpha is not None:
                result_alpha = inpaint_handle_alpha_channel(original_alpha, mask)
                return np.concatenate([inpainted, result_alpha], axis=2)
            return inpainted

    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
+5 −1
Original line number Diff line number Diff line
@@ -573,7 +573,11 @@ class OCR32pxModel:
            region = np.zeros((N, self.text_height, max_width, 3), dtype = np.uint8)
            for i, idx in enumerate(indices) :
                W = regions[idx].shape[1]
                region[i, :, : W, :] = regions[idx]
                # Convert RGBA to RGB if necessary for model input
                region_data = regions[idx]
                if region_data.shape[2] == 4:
                    region_data = cv2.cvtColor(region_data, cv2.COLOR_RGBA2RGB)
                region[i, :, : W, :] = region_data
            images = (torch.from_numpy(region).float() - 127.5) / 127.5
            images = einops.rearrange(images, 'N H W C -> N C H W')
            if self.device != 'cpu':
+5 −1
Original line number Diff line number Diff line
@@ -150,7 +150,11 @@ class Model48pxOCR:
            region = np.zeros((N, self.text_height, max_width, 3), dtype = np.uint8)
            for i, idx in enumerate(indices):
                W = regions[idx].shape[1]
                region[i, :, : W, :]=regions[idx]
                # Convert RGBA to RGB if necessary for model input
                region_data = regions[idx]
                if region_data.shape[2] == 4:
                    region_data = cv2.cvtColor(region_data, cv2.COLOR_RGBA2RGB)
                region[i, :, : W, :]=region_data

            image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5
            image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W')
+5 −1
Original line number Diff line number Diff line
@@ -75,7 +75,11 @@ class MangaOCR(OCRBase):
            x1, y1, x2, y2 = blk.xyxy
            if y2 < im_h and x2 < im_w and \
                x1 > 0 and y1 > 0 and x1 < x2 and y1 < y2: 
                blk.text = self.model(img[y1:y2, x1:x2])
                # Extract region and convert RGBA to RGB if necessary for model input
                region = img[y1:y2, x1:x2]
                if len(region.shape) == 3 and region.shape[2] == 4:
                    region = cv2.cvtColor(region, cv2.COLOR_RGBA2RGB)
                blk.text = self.model(region)
            else:
                self.logger.warning('invalid textbbox to target img')
                blk.text = ['']
+3 −0
Original line number Diff line number Diff line
@@ -232,6 +232,9 @@ if PADDLE_OCR_AVAILABLE:
                x1, y1, x2, y2 = blk.xyxy
                if 0 <= x1 < x2 <= im_w and 0 <= y1 < y2 <= im_h:
                    cropped_img = img[y1:y2, x1:x2]
                    # Convert RGBA to RGB if necessary for model input
                    if len(cropped_img.shape) == 3 and cropped_img.shape[2] == 4:
                        cropped_img = cv2.cvtColor(cropped_img, cv2.COLOR_RGBA2RGB)
                    try:
                        result = self.model.ocr(
                            cropped_img, det=True, rec=True, cls=self.use_angle_cls
Loading