Commit 22f9a0c0 authored by dmMaze's avatar dmMaze
Browse files

fix cpu text mask disalignment

parent d12b8f61
Loading
Loading
Loading
Loading
+7 −6
Original line number Diff line number Diff line
@@ -195,6 +195,8 @@ def model2annotations(model_path, img_dir_list, save_dir, save_json=False):
        cv2.imwrite(osp.join(save_dir, maskname), mask_refined)

def preprocess_img(img, detect_size=(1024, 1024), device='cpu', bgr2rgb=True, half=False, to_tensor=True):
    if isinstance(detect_size, int):
        detect_size = (detect_size, detect_size)
    if bgr2rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_in, ratio, (dw, dh) = letterbox(img, new_shape=detect_size, auto=False, stride=64)
@@ -248,8 +250,6 @@ class TextDetector:
        self.net: Union[TextDetBase, TextDetBaseDNN] = None
        self.backend: str = None
        
        if isinstance(detect_size, int):
            detect_size = (detect_size, detect_size)
        self.detect_size = detect_size
        self.device = device
        self.half = half
@@ -303,21 +303,22 @@ class TextDetector:

    @torch.no_grad()
    def __call__(self, img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False) -> Tuple[np.ndarray, np.ndarray, List[TextBlock]]:
        
        detect_size = self.detect_size if not self.backend == 'opencv' else 1024
        im_h, im_w = img.shape[:2]
        lines_map, mask = det_rearrange_forward(img, self.det_batch_forward_ctd, self.detect_size[0], self.det_rearrange_max_batches, self.device)
        lines_map, mask = det_rearrange_forward(img, self.det_batch_forward_ctd, detect_size, self.det_rearrange_max_batches, self.device)
        blks = []
        resize_ratio = [1, 1]
        if lines_map is None:
            img_in, ratio, dw, dh = preprocess_img(img, detect_size=self.detect_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
            img_in, ratio, dw, dh = preprocess_img(img, detect_size=detect_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
            blks, mask, lines_map = self.net(img_in)

            if self.backend == 'opencv':
                if mask.shape[1] == 2:     # some version of opencv spit out reversed result
                    tmp = mask
                    mask = lines_map
                    lines_map = tmp
            mask = mask.squeeze()
            resize_ratio = (im_w / (self.detect_size[0] - dw), im_h / (self.detect_size[1] - dh))
            resize_ratio = (im_w / (detect_size - dw), im_h / (detect_size - dh))
            blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)
            mask = mask[..., :mask.shape[0]-dh, :mask.shape[1]-dw]
            lines_map = lines_map[..., :lines_map.shape[2]-dh, :lines_map.shape[3]-dw]