Commit 6e90d3d2 authored by dmMaze's avatar dmMaze
Browse files

inpaint by block for cpu mode

parent 0464e6ff
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -8,8 +8,7 @@ data/testpacks/eng_dontupload
tmp.py
dummy_scripts.py

dist
PACKAGES
tmp

*.json
.vscode

data/pt2px.png

deleted100644 → 0
−52.6 KiB
Loading image diff...

data/px2pt.png

deleted100644 → 0
−5.01 KiB
Loading image diff...
+36 −8
Original line number Diff line number Diff line
@@ -2,16 +2,21 @@ import numpy as np
import cv2
from typing import Dict, List

from ..textdetector import TextBlock


from utils.registry import Registry
from utils.textblock_mask import canny_flood, connected_canny_flood
from utils.imgproc_utils import enlarge_window

INPAINTERS = Registry('inpainters')
register_inpainter = INPAINTERS.register_module

from ..moduleparamparser import ModuleParamParser, DEFAULT_DEVICE
from ..textdetector import TextBlock

class InpainterBase(ModuleParamParser):

    inpaint_by_block = True
    def __init__(self, **setup_params) -> None:
        super().__init__(**setup_params)
        self.name = ''
@@ -25,6 +30,20 @@ class InpainterBase(ModuleParamParser):
        raise NotImplementedError

    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        if not self.inpaint_by_block or textblock_list is None:
            return self._inpaint(img, mask)
        else:
            im_h, im_w = img.shape[:2]
            inpainted = img
            for blk in textblock_list:
                xyxy = blk.xyxy
                xyxy_e = enlarge_window(xyxy, im_w, im_h, ratio=1.5)
                im = inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                msk = mask[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]]
                inpainted[xyxy_e[1]:xyxy_e[3], xyxy_e[0]:xyxy_e[2]] = self._inpaint(im, msk)
            return inpainted

    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        raise NotImplementedError


@@ -34,7 +53,7 @@ class OpenCVInpainter(InpainterBase):
    def setup_inpainter(self):
        self.inpaint_method = lambda img, mask, *args, **kwargs: cv2.inpaint(img, mask, 3, cv2.INPAINT_NS)
    
    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        return self.inpaint_method(img, mask)


@@ -45,7 +64,7 @@ class PatchmatchInpainter(InpainterBase):
        from . import patch_match
        self.inpaint_method = lambda img, mask, *args, **kwargs: patch_match.inpaint(img, mask, patch_size=3)
    
    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
        return self.inpaint_method(img, mask)


@@ -98,6 +117,7 @@ class AOTInpainter(InpainterBase):
        else:
            self.model = AOTMODEL
            self.model.to(self.device)
        self.inpaint_by_block = True if self.device == 'cuda' else False
        self.inpaint_size = int(self.setup_params['inpaint_size']['select'])

    def inpaint_preprocess(self, img: np.ndarray, mask: np.ndarray) -> np.ndarray:
@@ -131,7 +151,7 @@ class AOTInpainter(InpainterBase):
        return img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right

    @torch.no_grad()
    def inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:
    def _inpaint(self, img: np.ndarray, mask: np.ndarray, textblock_list: List[TextBlock] = None) -> np.ndarray:

        im_h, im_w = img.shape[:2]
        img_torch, mask_torch, img_original, mask_original, pad_bottom, pad_right = self.inpaint_preprocess(img, mask)
@@ -150,7 +170,15 @@ class AOTInpainter(InpainterBase):

    def updateParam(self, param_key: str, param_content):
        super().updateParam(param_key, param_content)

        if param_key == 'device':
            param_device = self.setup_params['device']['select']
            self.model.to(param_device)
            self.device = param_device
            if param_device == 'cuda':
                self.inpaint_by_block = False
            else:
                self.inpaint_by_block = True
                
        elif param_key == 'inpaint_size':
            self.inpaint_size = int(self.setup_params['inpaint_size']['select'])
 No newline at end of file
+1 −0
Original line number Diff line number Diff line
@@ -4,4 +4,5 @@ torch
transformers
fugashi
unidic_lite
tqdm
opencv-python>=4.5.4
 No newline at end of file
Loading