Commit 44591903 authored by dmMaze's avatar dmMaze
Browse files

fix alpha inpainting

parent 00e4a6a2
Loading
Loading
Loading
Loading
+10 −8
Original line number Diff line number Diff line
@@ -22,21 +22,23 @@ def inpaint_handle_alpha_channel(original_alpha, mask):
    '''

    result_alpha = original_alpha.copy()
    return result_alpha

    # Analyze the alpha values around the original mask to determine appropriate transparency
    mask_dilated = cv2.dilate((mask > 0).astype(np.uint8), np.ones((15, 15), np.uint8), iterations=1)
    surrounding_mask = mask_dilated - (mask > 0).astype(np.uint8)
    mask_dilated = cv2.dilate((mask > 127).astype(np.uint8), np.ones((15, 15), np.uint8), iterations=1)
    surrounding_mask = mask_dilated - (mask > 127).astype(np.uint8)

    if np.any(surrounding_mask > 0):
        surrounding_alpha = original_alpha[surrounding_mask > 0]
        if len(surrounding_alpha) > 0:
            median_surrounding_alpha = np.median(surrounding_alpha)
            result_alpha[surrounding_mask] = median_surrounding_alpha
            # If surrounding area is mostly transparent (median alpha < 128),
            # make inpainted areas transparent too
            if median_surrounding_alpha < 128:
                inpainted_mask = (mask > 127)
                result_alpha[inpainted_mask] = median_surrounding_alpha

    return result_alpha


class InpainterBase(BaseModule):

    inpaint_by_block = True
@@ -152,7 +154,7 @@ class InpainterBase(BaseModule):
            
            # Recombine with alpha if original was RGBA
            if original_alpha is not None:
                result_alpha = inpaint_handle_alpha_channel(original_alpha, mask)
                result_alpha = inpaint_handle_alpha_channel(original_alpha, original_mask)
                return np.concatenate([inpainted, result_alpha], axis=2)
            return inpainted

+1 −0
Original line number Diff line number Diff line
@@ -803,6 +803,7 @@ class DrawingPanel(Widget):
            balloon_areas = np.where(ballon_mask > 0)
            if len(img.shape) == 3 and img.shape[2] == 4:
                avg_alpha = np.mean(img[balloon_areas][..., 3])
                avg_alpha = 0 if avg_alpha < 127 else avg_alpha
                bg_pixel_value.append(avg_alpha)
            bg_pixel_value = np.array(np.round(bg_pixel_value), dtype=np.uint8)
            img[balloon_areas] = bg_pixel_value