Commit b2d76ca1 authored by dmMaze's avatar dmMaze
Browse files

trying to impl extract_ballon_mask

parent 6e90d3d2
Loading
Loading
Loading
Loading
+54 −2
Original line number Diff line number Diff line
@@ -30,16 +30,68 @@ class InpainterBase(ModuleParamParser):
        raise NotImplementedError

    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        def extract_ballon_mask(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
            # img = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
            h, w = img.shape[:2]
            text_sum = np.sum(mask)
            cannyed = cv2.Canny(img, 70, 140, L2gradient=True, apertureSize=3)
            br = cv2.boundingRect(cv2.findNonZero(mask))
            br_xyxy = [br[0], br[1], br[0] + br[2], br[1] + br[3]]

            cv2.rectangle(cannyed, (0, 0), (w-1, h-1), (255, 255, 255), 1, cv2.LINE_8)
            cannyed = cv2.bitwise_and(cannyed, 255 - mask)

            cons, _ = cv2.findContours(cannyed, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
            min_ballon_area = w * h
            ballon_mask = None
            non_text_mask = None
            for ii, con in enumerate(cons):
                br_c = cv2.boundingRect(con)
                br_c = [br_c[0], br_c[1], br_c[0] + br_c[2], br_c[1] + br_c[3]]
                if br_c[0] > br_xyxy[0] or br_c[1] > br_xyxy[1] or br_c[2] < br_xyxy[2] or br_c[3] < br_xyxy[3]:
                    continue
                tmp = np.zeros_like(cannyed)
                cv2.drawContours(tmp, cons, ii, (255, 255, 255), -1, cv2.LINE_8)
                if cv2.bitwise_and(tmp, mask).sum() >= text_sum:
                    con_area = cv2.contourArea(con)
                    if con_area < min_ballon_area:
                        min_ballon_area = con_area
                        ballon_mask = tmp
            if ballon_mask is not None:
                non_text_mask = cv2.bitwise_and(ballon_mask, 255 - mask)
            #     cv2.imshow('ballon', ballon_mask)
            #     cv2.imshow('non_text', non_text_mask)
            # cv2.imshow('im', img)
            # cv2.imshow('msk', mask)
            # cv2.imshow('br', mask[br_xyxy[1]:br_xyxy[3], br_xyxy[0]:br_xyxy[2]])
            # cv2.imshow('canny', cannyed)
                
            # cv2.waitKey(0)
            # return msk
            return ballon_mask, non_text_mask


        if not self.inpaint_by_block or textblock_list is None:
            return self._inpaint(img, mask)
        else:
            im_h, im_w = img.shape[:2]
            inpainted = img
            inpainted = np.copy(img)
            for blk in textblock_list:
                xyxy = blk.xyxy
                xyxy_e = enlarge_window(xyxy, im_w, im_h, ratio=1.5)
                im = inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                msk = mask[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                # ballon_msk, non_text_msk = extract_ballon_mask(im, msk)
                # if ballon_msk is not None:
                #     non_text_region = np.where(non_text_msk > 0)
                #     non_text_px = im[non_text_region]
                #     average_bg_color = np.mean(non_text_px, axis=0)
                #     std = np.std(non_text_px - average_bg_color, axis=0)
                #     print(average_bg_color, std)
                #     cv2.imshow('im', im)
                #     cv2.imshow('ballon', ballon_msk)
                #     cv2.imshow('non_text', non_text_msk)
                #     cv2.waitKey(0)
                inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]] = self._inpaint(im, msk)
            return inpainted

+64 −64
Original line number Diff line number Diff line
@@ -168,71 +168,71 @@ def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[TextBlock
        mask_refined[by1: by2, bx1: bx2] = cv2.bitwise_or(mask_refined[by1: by2, bx1: bx2], mask_merged)
    return mask_refined

# def extract_textballoon(img, pred_textmsk=None, global_mask=None):
#     if len(img.shape) > 2 and img.shape[2] == 3:
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#     im_h, im_w = img.shape[0], img.shape[1]
#     hyp_textmsk = np.zeros((im_h, im_w), np.uint8)
#     thresh_val, threshed = cv2.threshold(img, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
#     xormap_sum = cv2.bitwise_xor(threshed, pred_textmsk).sum()
#     neg_threshed = 255 - threshed
#     neg_xormap_sum = cv2.bitwise_xor(neg_threshed, pred_textmsk).sum()
#     neg_thresh = neg_xormap_sum < xormap_sum
#     if neg_thresh:
#         threshed = neg_threshed
#     thresh_info = {'thresh_val': thresh_val,'neg_thresh': neg_thresh}
#     connectivity = 8
#     num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(threshed, connectivity, cv2.CV_16U)
#     label_unchanged = np.copy(labels)
#     if global_mask is not None:
#         labels[np.where(global_mask==0)] = 0
#     text_labels = []
#     if pred_textmsk is not None:
#         text_score_thresh = 0.5
#         textbbox_map = np.zeros_like(pred_textmsk)
#         for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
#             if label_index != 0: # skip background label
#                 x, y, w, h, area = stat
#                 area *= 255
#                 x1, y1, x2, y2 = x, y, x+w, y+h
#                 label_local = labels[y1: y2, x1: x2]
#                 label_cordinates = np.where(label_local==label_index)
#                 tmp_merged = np.zeros((h, w), np.uint8)
#                 tmp_merged[label_cordinates] = 255
#                 andmap = cv2.bitwise_and(tmp_merged, pred_textmsk[y1: y2, x1: x2])
#                 text_score = andmap.sum() / area
#                 if text_score > text_score_thresh:
#                     text_labels.append(label_index)
#                     hyp_textmsk[y1: y2, x1: x2][label_cordinates] = 255
#     labels = label_unchanged
#     bubble_msk = np.zeros((img.shape[0], img.shape[1]), np.uint8)
#     bubble_msk[np.where(labels==0)] = 255
#     # if lang == LANG_JPN:
#     bubble_msk = cv2.erode(bubble_msk, (3, 3), iterations=1)
#     line_thickness = 2
#     cv2.rectangle(bubble_msk, (0, 0), (im_w, im_h), BLACK, line_thickness, cv2.LINE_8)
#     contours, hiers = cv2.findContours(bubble_msk, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
def extract_textballoon(img, pred_textmsk=None, global_mask=None):
    if len(img.shape) > 2 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    im_h, im_w = img.shape[0], img.shape[1]
    hyp_textmsk = np.zeros((im_h, im_w), np.uint8)
    thresh_val, threshed = cv2.threshold(img, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
    xormap_sum = cv2.bitwise_xor(threshed, pred_textmsk).sum()
    neg_threshed = 255 - threshed
    neg_xormap_sum = cv2.bitwise_xor(neg_threshed, pred_textmsk).sum()
    neg_thresh = neg_xormap_sum < xormap_sum
    if neg_thresh:
        threshed = neg_threshed
    thresh_info = {'thresh_val': thresh_val,'neg_thresh': neg_thresh}
    connectivity = 8
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(threshed, connectivity, cv2.CV_16U)
    label_unchanged = np.copy(labels)
    if global_mask is not None:
        labels[np.where(global_mask==0)] = 0
    text_labels = []
    if pred_textmsk is not None:
        text_score_thresh = 0.5
        textbbox_map = np.zeros_like(pred_textmsk)
        for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
            if label_index != 0: # skip background label
                x, y, w, h, area = stat
                area *= 255
                x1, y1, x2, y2 = x, y, x+w, y+h
                label_local = labels[y1: y2, x1: x2]
                label_cordinates = np.where(label_local==label_index)
                tmp_merged = np.zeros((h, w), np.uint8)
                tmp_merged[label_cordinates] = 255
                andmap = cv2.bitwise_and(tmp_merged, pred_textmsk[y1: y2, x1: x2])
                text_score = andmap.sum() / area
                if text_score > text_score_thresh:
                    text_labels.append(label_index)
                    hyp_textmsk[y1: y2, x1: x2][label_cordinates] = 255
    labels = label_unchanged
    bubble_msk = np.zeros((img.shape[0], img.shape[1]), np.uint8)
    bubble_msk[np.where(labels==0)] = 255
    # if lang == LANG_JPN:
    bubble_msk = cv2.erode(bubble_msk, (3, 3), iterations=1)
    line_thickness = 2
    cv2.rectangle(bubble_msk, (0, 0), (im_w, im_h), BLACK, line_thickness, cv2.LINE_8)
    contours, hiers = cv2.findContours(bubble_msk, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)

#     brect_area_thresh = im_h * im_w * 0.4
#     min_brect_area = np.inf
#     ballon_index = -1
#     maxium_pixsum = -1
#     for ii, contour in enumerate(contours):
#         brect = cv2.boundingRect(contours[ii])
#         brect_area = brect[2] * brect[3]
#         if brect_area > brect_area_thresh and brect_area < min_brect_area:
#             tmp_ballonmsk = np.zeros_like(bubble_msk)
#             tmp_ballonmsk = cv2.drawContours(tmp_ballonmsk, contours, ii, WHITE, cv2.FILLED)
#             andmap_sum = cv2.bitwise_and(tmp_ballonmsk, hyp_textmsk).sum()
#             if andmap_sum > maxium_pixsum:
#                 maxium_pixsum = andmap_sum
#                 min_brect_area = brect_area
#                 ballon_index = ii
#     if ballon_index != -1:
#         bubble_msk = np.zeros_like(bubble_msk)
#         bubble_msk = cv2.drawContours(bubble_msk, contours, ballon_index, WHITE, cv2.FILLED)
#     hyp_textmsk = cv2.bitwise_and(hyp_textmsk, bubble_msk)
#     return hyp_textmsk, bubble_msk, thresh_info, (num_labels, label_unchanged, stats, centroids, text_labels)
    brect_area_thresh = im_h * im_w * 0.4
    min_brect_area = np.inf
    ballon_index = -1
    maxium_pixsum = -1
    for ii, contour in enumerate(contours):
        brect = cv2.boundingRect(contours[ii])
        brect_area = brect[2] * brect[3]
        if brect_area > brect_area_thresh and brect_area < min_brect_area:
            tmp_ballonmsk = np.zeros_like(bubble_msk)
            tmp_ballonmsk = cv2.drawContours(tmp_ballonmsk, contours, ii, WHITE, cv2.FILLED)
            andmap_sum = cv2.bitwise_and(tmp_ballonmsk, hyp_textmsk).sum()
            if andmap_sum > maxium_pixsum:
                maxium_pixsum = andmap_sum
                min_brect_area = brect_area
                ballon_index = ii
    if ballon_index != -1:
        bubble_msk = np.zeros_like(bubble_msk)
        bubble_msk = cv2.drawContours(bubble_msk, contours, ballon_index, WHITE, cv2.FILLED)
    hyp_textmsk = cv2.bitwise_and(hyp_textmsk, bubble_msk)
    return hyp_textmsk, bubble_msk, thresh_info, (num_labels, label_unchanged, stats, centroids, text_labels)

# def extract_textballoon_channelwise(img, pred_textmsk, test_grey=True, global_mask=None):
#     c_list = [img[:, :, i] for i in range(3)]
+1 −10
Original line number Diff line number Diff line
import sys, os
import os.path as osp
sys.path.append(osp.dirname(osp.dirname(__file__)))
 No newline at end of file


from test_translators import dosth




if __name__ == '__main__':
    print(dosth())
 No newline at end of file