Commit 012d1260 authored by dmMaze's avatar dmMaze
Browse files

impl extract_ballon_mask

parent b2d76ca1
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ mask
result
data/models
data/testpacks/eng_dontupload
data/testpacks/testpacks

tmp.py
dummy_scripts.py
+25 −54
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from typing import Dict, List


from utils.registry import Registry
from utils.textblock_mask import canny_flood, connected_canny_flood
from utils.textblock_mask import canny_flood, connected_canny_flood, extract_ballon_mask
from utils.imgproc_utils import enlarge_window

INPAINTERS = Registry('inpainters')
@@ -17,6 +17,7 @@ from ..textdetector import TextBlock
class InpainterBase(ModuleParamParser):

    inpaint_by_block = True
    check_need_inpaint = True
    def __init__(self, **setup_params) -> None:
        super().__init__(**setup_params)
        self.name = ''
@@ -30,46 +31,6 @@ class InpainterBase(ModuleParamParser):
        raise NotImplementedError

    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        def extract_ballon_mask(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
            # img = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
            h, w = img.shape[:2]
            text_sum = np.sum(mask)
            cannyed = cv2.Canny(img, 70, 140, L2gradient=True, apertureSize=3)
            br = cv2.boundingRect(cv2.findNonZero(mask))
            br_xyxy = [br[0], br[1], br[0] + br[2], br[1] + br[3]]

            cv2.rectangle(cannyed, (0, 0), (w-1, h-1), (255, 255, 255), 1, cv2.LINE_8)
            cannyed = cv2.bitwise_and(cannyed, 255 - mask)

            cons, _ = cv2.findContours(cannyed, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
            min_ballon_area = w * h
            ballon_mask = None
            non_text_mask = None
            for ii, con in enumerate(cons):
                br_c = cv2.boundingRect(con)
                br_c = [br_c[0], br_c[1], br_c[0] + br_c[2], br_c[1] + br_c[3]]
                if br_c[0] > br_xyxy[0] or br_c[1] > br_xyxy[1] or br_c[2] < br_xyxy[2] or br_c[3] < br_xyxy[3]:
                    continue
                tmp = np.zeros_like(cannyed)
                cv2.drawContours(tmp, cons, ii, (255, 255, 255), -1, cv2.LINE_8)
                if cv2.bitwise_and(tmp, mask).sum() >= text_sum:
                    con_area = cv2.contourArea(con)
                    if con_area < min_ballon_area:
                        min_ballon_area = con_area
                        ballon_mask = tmp
            if ballon_mask is not None:
                non_text_mask = cv2.bitwise_and(ballon_mask, 255 - mask)
            #     cv2.imshow('ballon', ballon_mask)
            #     cv2.imshow('non_text', non_text_mask)
            # cv2.imshow('im', img)
            # cv2.imshow('msk', mask)
            # cv2.imshow('br', mask[br_xyxy[1]:br_xyxy[3], br_xyxy[0]:br_xyxy[2]])
            # cv2.imshow('canny', cannyed)
                
            # cv2.waitKey(0)
            # return msk
            return ballon_mask, non_text_mask


        if not self.inpaint_by_block or textblock_list is None:
            return self._inpaint(img, mask)
@@ -78,21 +39,31 @@ class InpainterBase(ModuleParamParser):
            inpainted = np.copy(img)
            for blk in textblock_list:
                xyxy = blk.xyxy
                xyxy_e = enlarge_window(xyxy, im_w, im_h, ratio=1.5)
                xyxy_e = enlarge_window(xyxy, im_w, im_h, ratio=2)
                im = inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                msk = mask[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                # ballon_msk, non_text_msk = extract_ballon_mask(im, msk)
                # if ballon_msk is not None:
                #     non_text_region = np.where(non_text_msk > 0)
                #     non_text_px = im[non_text_region]
                #     average_bg_color = np.mean(non_text_px, axis=0)
                #     std = np.std(non_text_px - average_bg_color, axis=0)
                #     print(average_bg_color, std)
                need_inpaint = True
                if self.check_need_inpaint:
                    ballon_msk, non_text_msk = extract_ballon_mask(im, msk)
                    if ballon_msk is not None:
                        non_text_region = np.where(non_text_msk > 0)
                        non_text_px = im[non_text_region]
                        average_bg_color = np.mean(non_text_px, axis=0)
                        std_bgr = np.std(non_text_px - average_bg_color, axis=0)
                        std_max = np.max(std_bgr)
                        inpaint_thresh = 7 if np.std(std_bgr) > 1 else 10
                        if std_max < inpaint_thresh:
                            need_inpaint = False
                            inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]][np.where(ballon_msk > 0)] = average_bg_color
                        # cv2.imshow('im', im)
                        # cv2.imshow('ballon', ballon_msk)
                        # cv2.imshow('non_text', non_text_msk)
                        # cv2.waitKey(0)
                
                if need_inpaint:
                    inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]] = self._inpaint(im, msk)

                mask[xyxy[1]:xyxy[3], xyxy[0]:xyxy[2]] = 0
            return inpainted

    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
+64 −64
Original line number Diff line number Diff line
@@ -168,71 +168,71 @@ def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[TextBlock
        mask_refined[by1: by2, bx1: bx2] = cv2.bitwise_or(mask_refined[by1: by2, bx1: bx2], mask_merged)
    return mask_refined

def extract_textballoon(img, pred_textmsk=None, global_mask=None):
    if len(img.shape) > 2 and img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    im_h, im_w = img.shape[0], img.shape[1]
    hyp_textmsk = np.zeros((im_h, im_w), np.uint8)
    thresh_val, threshed = cv2.threshold(img, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
    xormap_sum = cv2.bitwise_xor(threshed, pred_textmsk).sum()
    neg_threshed = 255 - threshed
    neg_xormap_sum = cv2.bitwise_xor(neg_threshed, pred_textmsk).sum()
    neg_thresh = neg_xormap_sum < xormap_sum
    if neg_thresh:
        threshed = neg_threshed
    thresh_info = {'thresh_val': thresh_val,'neg_thresh': neg_thresh}
    connectivity = 8
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(threshed, connectivity, cv2.CV_16U)
    label_unchanged = np.copy(labels)
    if global_mask is not None:
        labels[np.where(global_mask==0)] = 0
    text_labels = []
    if pred_textmsk is not None:
        text_score_thresh = 0.5
        textbbox_map = np.zeros_like(pred_textmsk)
        for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
            if label_index != 0: # skip background label
                x, y, w, h, area = stat
                area *= 255
                x1, y1, x2, y2 = x, y, x+w, y+h
                label_local = labels[y1: y2, x1: x2]
                label_cordinates = np.where(label_local==label_index)
                tmp_merged = np.zeros((h, w), np.uint8)
                tmp_merged[label_cordinates] = 255
                andmap = cv2.bitwise_and(tmp_merged, pred_textmsk[y1: y2, x1: x2])
                text_score = andmap.sum() / area
                if text_score > text_score_thresh:
                    text_labels.append(label_index)
                    hyp_textmsk[y1: y2, x1: x2][label_cordinates] = 255
    labels = label_unchanged
    bubble_msk = np.zeros((img.shape[0], img.shape[1]), np.uint8)
    bubble_msk[np.where(labels==0)] = 255
    # if lang == LANG_JPN:
    bubble_msk = cv2.erode(bubble_msk, (3, 3), iterations=1)
    line_thickness = 2
    cv2.rectangle(bubble_msk, (0, 0), (im_w, im_h), BLACK, line_thickness, cv2.LINE_8)
    contours, hiers = cv2.findContours(bubble_msk, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
# def extract_textballoon(img, pred_textmsk=None, global_mask=None):
#     if len(img.shape) > 2 and img.shape[2] == 3:
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
#     im_h, im_w = img.shape[0], img.shape[1]
#     hyp_textmsk = np.zeros((im_h, im_w), np.uint8)
#     thresh_val, threshed = cv2.threshold(img, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
#     xormap_sum = cv2.bitwise_xor(threshed, pred_textmsk).sum()
#     neg_threshed = 255 - threshed
#     neg_xormap_sum = cv2.bitwise_xor(neg_threshed, pred_textmsk).sum()
#     neg_thresh = neg_xormap_sum < xormap_sum
#     if neg_thresh:
#         threshed = neg_threshed
#     thresh_info = {'thresh_val': thresh_val,'neg_thresh': neg_thresh}
#     connectivity = 8
#     num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(threshed, connectivity, cv2.CV_16U)
#     label_unchanged = np.copy(labels)
#     if global_mask is not None:
#         labels[np.where(global_mask==0)] = 0
#     text_labels = []
#     if pred_textmsk is not None:
#         text_score_thresh = 0.5
#         textbbox_map = np.zeros_like(pred_textmsk)
#         for label_index, stat, centroid in zip(range(num_labels), stats, centroids):
#             if label_index != 0: # skip background label
#                 x, y, w, h, area = stat
#                 area *= 255
#                 x1, y1, x2, y2 = x, y, x+w, y+h
#                 label_local = labels[y1: y2, x1: x2]
#                 label_cordinates = np.where(label_local==label_index)
#                 tmp_merged = np.zeros((h, w), np.uint8)
#                 tmp_merged[label_cordinates] = 255
#                 andmap = cv2.bitwise_and(tmp_merged, pred_textmsk[y1: y2, x1: x2])
#                 text_score = andmap.sum() / area
#                 if text_score > text_score_thresh:
#                     text_labels.append(label_index)
#                     hyp_textmsk[y1: y2, x1: x2][label_cordinates] = 255
#     labels = label_unchanged
#     bubble_msk = np.zeros((img.shape[0], img.shape[1]), np.uint8)
#     bubble_msk[np.where(labels==0)] = 255
#     # if lang == LANG_JPN:
#     bubble_msk = cv2.erode(bubble_msk, (3, 3), iterations=1)
#     line_thickness = 2
#     cv2.rectangle(bubble_msk, (0, 0), (im_w, im_h), BLACK, line_thickness, cv2.LINE_8)
#     contours, hiers = cv2.findContours(bubble_msk, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)

    brect_area_thresh = im_h * im_w * 0.4
    min_brect_area = np.inf
    ballon_index = -1
    maxium_pixsum = -1
    for ii, contour in enumerate(contours):
        brect = cv2.boundingRect(contours[ii])
        brect_area = brect[2] * brect[3]
        if brect_area > brect_area_thresh and brect_area < min_brect_area:
            tmp_ballonmsk = np.zeros_like(bubble_msk)
            tmp_ballonmsk = cv2.drawContours(tmp_ballonmsk, contours, ii, WHITE, cv2.FILLED)
            andmap_sum = cv2.bitwise_and(tmp_ballonmsk, hyp_textmsk).sum()
            if andmap_sum > maxium_pixsum:
                maxium_pixsum = andmap_sum
                min_brect_area = brect_area
                ballon_index = ii
    if ballon_index != -1:
        bubble_msk = np.zeros_like(bubble_msk)
        bubble_msk = cv2.drawContours(bubble_msk, contours, ballon_index, WHITE, cv2.FILLED)
    hyp_textmsk = cv2.bitwise_and(hyp_textmsk, bubble_msk)
    return hyp_textmsk, bubble_msk, thresh_info, (num_labels, label_unchanged, stats, centroids, text_labels)
#     brect_area_thresh = im_h * im_w * 0.4
#     min_brect_area = np.inf
#     ballon_index = -1
#     maxium_pixsum = -1
#     for ii, contour in enumerate(contours):
#         brect = cv2.boundingRect(contours[ii])
#         brect_area = brect[2] * brect[3]
#         if brect_area > brect_area_thresh and brect_area < min_brect_area:
#             tmp_ballonmsk = np.zeros_like(bubble_msk)
#             tmp_ballonmsk = cv2.drawContours(tmp_ballonmsk, contours, ii, WHITE, cv2.FILLED)
#             andmap_sum = cv2.bitwise_and(tmp_ballonmsk, hyp_textmsk).sum()
#             if andmap_sum > maxium_pixsum:
#                 maxium_pixsum = andmap_sum
#                 min_brect_area = brect_area
#                 ballon_index = ii
#     if ballon_index != -1:
#         bubble_msk = np.zeros_like(bubble_msk)
#         bubble_msk = cv2.drawContours(bubble_msk, contours, ballon_index, WHITE, cv2.FILLED)
#     hyp_textmsk = cv2.bitwise_and(hyp_textmsk, bubble_msk)
#     return hyp_textmsk, bubble_msk, thresh_info, (num_labels, label_unchanged, stats, centroids, text_labels)

# def extract_textballoon_channelwise(img, pred_textmsk, test_grey=True, global_mask=None):
#     c_list = [img[:, :, i] for i in range(3)]
+6 −9
Original line number Diff line number Diff line
import sys, os
import os.path as osp

from typing import Dict, List
import functools
sys.path.append(os.getcwd())

from dl import InpainterBase, AOTInpainter, PatchmatchInpainter
@@ -10,7 +6,6 @@ from utils.io_utils import imread, imwrite, find_all_imgs
from ui.misc import ProjImgTrans

from ui.misc import ProjImgTrans, DLModuleConfig
import json
import numpy as np
import cv2
from tqdm import tqdm
@@ -62,12 +57,14 @@ def test_patchmatch(proj: ProjImgTrans, inpaint_by_block=True, show=False):

if __name__ == '__main__':

    manga_dir = 'data/testpacks/manga'
    comic_dir = 'data/testpacks/comic'
    manga_dir = 'data/testpacks/testpacks/jpn'
    comic_dir = 'data/testpacks/testpacks/eng'
    comic_dir2 = 'data/testpacks/testpacks/eng2'
    manga_proj = ProjImgTrans(manga_dir)

    comic_proj = ProjImgTrans(comic_dir)
    # comic_proj = ProjImgTrans(comic_dir2)
    # test_aot(manga_proj, device='cpu', inpaint_by_block=True)
    test_patchmatch(manga_proj, inpaint_by_block=True)
    test_patchmatch(comic_proj, inpaint_by_block=True)
    
    
+46 −1
Original line number Diff line number Diff line
@@ -344,3 +344,48 @@ def connected_canny_flood(img, show_process=False, inpaint_sdthresh=10, inpaint=
                "bground_bgr": bground_aver,
                "inner_rect": inner_rect}
    return text_mask, paint_res, bub_dict


def extract_ballon_mask(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    img = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
    h, w = img.shape[:2]
    text_sum = np.sum(mask)
    
    # _, threshed = cv2.threshold(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
    cannyed = cv2.Canny(img, 70, 140, L2gradient=True, apertureSize=3)
    e_size = 1
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
    cannyed = cv2.dilate(cannyed, element, iterations=1)
    br = cv2.boundingRect(cv2.findNonZero(mask))
    br_xyxy = [br[0], br[1], br[0] + br[2], br[1] + br[3]]

    cv2.rectangle(cannyed, (0, 0), (w-1, h-1), (255, 255, 255), 1, cv2.LINE_8)
    cannyed = cv2.bitwise_and(cannyed, 255 - mask)

    cons, _ = cv2.findContours(cannyed, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    min_ballon_area = w * h
    ballon_mask = None
    non_text_mask = None
    for ii, con in enumerate(cons):
        br_c = cv2.boundingRect(con)
        br_c = [br_c[0], br_c[1], br_c[0] + br_c[2], br_c[1] + br_c[3]]
        if br_c[0] > br_xyxy[0] or br_c[1] > br_xyxy[1] or br_c[2] < br_xyxy[2] or br_c[3] < br_xyxy[3]:
            continue
        tmp = np.zeros_like(cannyed)
        cv2.drawContours(tmp, cons, ii, (255, 255, 255), -1, cv2.LINE_8)
        if cv2.bitwise_and(tmp, mask).sum() >= text_sum:
            con_area = cv2.contourArea(con)
            if con_area < min_ballon_area:
                min_ballon_area = con_area
                ballon_mask = tmp
    if ballon_mask is not None:
        non_text_mask = cv2.bitwise_and(ballon_mask, 255 - mask)
    #     cv2.imshow('ballon', ballon_mask)
    #     cv2.imshow('non_text', non_text_mask)
    # cv2.imshow('im', img)
    # cv2.imshow('msk', mask)
    # cv2.imshow('br', mask[br_xyxy[1]:br_xyxy[3], br_xyxy[0]:br_xyxy[2]])
    # cv2.imshow('canny', cannyed)
    # cv2.waitKey(0)

    return ballon_mask, non_text_mask
 No newline at end of file