Commit d12b8f61 authored by dmMaze's avatar dmMaze
Browse files

Rearrange image to handle extreme aspect ratio

parent a495e928
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ tmp.py
dummy_scripts.py

tmp
TEST*

*.json
.vscode
+173 −26
Original line number Diff line number Diff line
@@ -6,20 +6,141 @@ import numpy as np
import cv2
import torch
from pathlib import Path
import torch
import einops

from utils.io_utils import find_all_imgs, NumpyEncoder
from utils.imgproc_utils import letterbox, xyxy2yolo, get_yololabel_strings
from utils.imgproc_utils import letterbox, xyxy2yolo, get_yololabel_strings, square_pad_resize

from ..yolov5.yolov5_utils import non_max_suppression
from ..db_utils import SegDetectorRepresenter
from ..textblock import TextBlock, group_output
from .textmask import refine_mask, refine_undetected_mask, REFINEMASK_INPAINT, REFINEMASK_ANNOTATION
from pathlib import Path
from typing import Union, List, Tuple
from typing import Union, List, Tuple, Callable

CTD_MODEL_PATH = r'data/models/comictextdetector.pt'

def det_rearrange_forward(
    img: np.ndarray, 
    dbnet_batch_forward: Callable[[np.ndarray, str], Tuple[np.ndarray, np.ndarray]], 
    tgt_size: int = 1280, 
    max_batch_size: int = 4, 
    device='cuda', verbose=False) -> Tuple[Union[np.ndarray, None], Union[np.ndarray, None]]:
    '''
    Rearrange image to square batches before feeding into network if following conditions are satisfied: \n
    1. Extreme aspect ratio (long side / short side > 3)
    2. Is too tall or wide for detect size (long side / tgt_size > 2.5)

    The long side is cut into overlapping patches, groups of patches are tiled
    side by side into roughly square images, each of those is passed through
    ``square_pad_resize`` and then through ``dbnet_batch_forward`` in batches
    of at most ``max_batch_size``. The network outputs are cropped and stitched
    back into maps laid out like the input image.

    Args:
        img: input image indexed as (h, w, c).
        dbnet_batch_forward: called as ``dbnet_batch_forward(batch, device=device)``;
            must return ``(db, mask)``, two arrays that are iterated per batch item.
        tgt_size: side length passed to ``square_pad_resize``.
        max_batch_size: maximum number of rearranged images per forward call.
        device: forwarded unchanged to ``dbnet_batch_forward``.
        verbose: if True, print a notice and write every rearranged image to
            ``result/rearrange_<idx>.jpg`` (the directory is not created here).

    Returns:
        DBNet output, mask or None, None if rearrangement is not required.
        Both returned maps carry a leading batch axis of size 1.
    '''

    def _unrearrange(patch_lst: List[np.ndarray], transpose: bool, channel=1, pad_num=0):
        # Stitch the per-image network outputs back into a single map.
        # Reads ph_step, patch_size, pw_num and ph_num from the enclosing scope.
        _psize = _h = patch_lst[0].shape[-1]
        # the patch step and strip width, rescaled from input resolution to output resolution
        _step = int(ph_step * _psize / patch_size)
        _pw = int(_psize / pw_num)
        if ph_num > 1:
            _h += (ph_num - 1) * _step
        tgtmap = np.zeros((channel, _h, _pw), dtype=np.float32)
        # number of real (non zero-padded) patches
        num_patches = len(patch_lst) * pw_num - pad_num
        for ii, p in enumerate(patch_lst):
            if transpose:
                p = einops.rearrange(p, 'c h w -> c w h')
            for jj in range(pw_num):
                pidx = ii * pw_num + jj
                t = pidx * _step
                b = t + _psize
                l = jj * _pw
                r = l + _pw

                tgtmap[..., t: b, :] += p[..., l: r]
                if pidx > 0:
                    # rows shared with the previous patch were summed; halve them to blend
                    interleave = _psize - _step
                    tgtmap[..., t: t+interleave, :] /= 2.

                if pidx >= num_patches - 1:
                    # the remaining strips of this image are zero padding
                    break

        if transpose:
            tgtmap = einops.rearrange(tgtmap, 'c h w -> c w h')
        return tgtmap[None, ...]

    def _patch2batches(patch_lst: List[np.ndarray], p_num: int, transpose: bool):
        # Tile every pw_num consecutive patches into one image (stacked along the
        # height axis when transposed, along the width axis otherwise), then
        # pad/resize each image and group them into batches.
        if transpose:
            patch_lst = einops.rearrange(patch_lst, '(p_num pw_num) ph pw c -> p_num (pw_num pw) ph c', p_num=p_num)
        else:
            patch_lst = einops.rearrange(patch_lst, '(p_num pw_num) ph pw c -> p_num ph (pw_num pw) c', p_num=p_num)
        
        batches = [[]]
        for ii, patch in enumerate(patch_lst):

            if len(batches[-1]) >= max_batch_size:
                batches.append([])
            p, down_scale_ratio, pad_h, pad_w = square_pad_resize(patch, tgt_size=tgt_size)

            # the tiled images are expected to be square, hence padded equally on both axes
            assert pad_h == pad_w
            pad_size = pad_h
            batches[-1].append(p)
            if verbose:
                cv2.imwrite(f'result/rearrange_{ii}.jpg', p[..., ::-1])
        return batches, down_scale_ratio, pad_size

    # Work in a canonical "tall" orientation: h is always the long side.
    h, w = img.shape[:2]
    transpose = False
    if h < w:
        transpose = True
        h, w = img.shape[1], img.shape[0]

    asp_ratio = h / w
    down_scale_ratio = h / tgt_size

    # rearrange condition
    require_rearrange = down_scale_ratio > 2.5 and asp_ratio > 3
    if not require_rearrange:
        return None, None
    else:

        if verbose:
            print(f'Input image will be rearranged to square batches before fed into network.\
                \n Rearranged batches will be saved to result/rearrange_%d.jpg')

        if transpose:
            img = einops.rearrange(img, 'h w c -> w h c')
        
        # number of strips tiled side by side per rearranged image (at least 2)
        pw_num = max(int(np.floor(2 * tgt_size / w)), 2)
        # patch height chosen so that pw_num strips of width w form a square
        patch_size = ph = pw_num * w

        ph_num = int(np.ceil(h / ph))
        # stride between patch tops; patches overlap so the last one ends within the image
        ph_step = int((h - ph) / (ph_num - 1)) if ph_num > 1 else 0
        patch_list = []
        for ii in range(ph_num):
            patch_list.append(img[ii * ph_step: ii * ph_step + ph])

        # pad with blank patches so the patch count is a multiple of pw_num
        p_num = int(np.ceil(ph_num / pw_num))
        pad_num = p_num * pw_num - ph_num
        for ii in range(pad_num):
            patch_list.append(np.zeros_like(patch_list[0]))

        batches, down_scale_ratio, pad_size = _patch2batches(patch_list, p_num, transpose)

        db_lst, mask_lst = [], []
        for batch in batches:
            batch = np.array(batch)
            db, mask = dbnet_batch_forward(batch, device=device)

            for d, m in zip(db, mask):
                if pad_size > 0:
                    # rescale the padding from input resolution to each output's resolution
                    paddb = int(db.shape[-1] / tgt_size * pad_size)
                    padmsk = int(mask.shape[-1] / tgt_size * pad_size)
                    # Crop with explicit end indices: a ``:-pad`` slice would yield an
                    # empty array whenever the rescaled padding truncates to 0.
                    d = d[..., :max(d.shape[-2] - paddb, 0), :max(d.shape[-1] - paddb, 0)]
                    m = m[..., :max(m.shape[-2] - padmsk, 0), :max(m.shape[-1] - padmsk, 0)]
                db_lst.append(d)
                mask_lst.append(m)

        db = _unrearrange(db_lst, transpose, channel=2, pad_num=pad_num)
        mask = _unrearrange(mask_lst, transpose, channel=1, pad_num=pad_num)
        return db, mask

def model2annotations(model_path, img_dir_list, save_dir, save_json=False):
    if isinstance(img_dir_list, str):
        img_dir_list = [img_dir_list]
@@ -91,7 +212,7 @@ def postprocess_mask(img: Union[torch.Tensor, np.ndarray], thresh=None):
    if isinstance(img, torch.Tensor):
        img = img.squeeze_()
        if img.device != 'cpu':
            img = img.detach_().cpu()
            img = img.detach().cpu()
        img = img.numpy()
    else:
        img = img.squeeze()
@@ -121,7 +242,7 @@ class TextDetector:
    lang_list = ['eng', 'ja', 'unknown']
    langcls2idx = {'eng': 0, 'ja': 1, 'unknown': 2}

    def __init__(self, model_path, detect_size=1024, device='cpu', half=False, nms_thresh=0.35, conf_thresh=0.4):
    def __init__(self, model_path, detect_size=1024, device='cpu', half=False, nms_thresh=0.35, conf_thresh=0.4, det_rearrange_max_batches=4):
        super(TextDetector, self).__init__()

        self.net: Union[TextDetBase, TextDetBaseDNN] = None
@@ -139,6 +260,8 @@ class TextDetector:
        self.backend = 'torch'
        self.load_model(model_path)

        self.det_rearrange_max_batches = det_rearrange_max_batches

    def load_model(self, model_path: str):
        if Path(model_path).suffix == '.onnx':
            self.net = TextDetBaseDNN(1024, model_path)
@@ -155,37 +278,61 @@ class TextDetector:
            raise FileNotFoundError(f'CTD model not found: {model_path}')
        self.load_model(model_path)

    def det_batch_forward_ctd(self, batch: np.ndarray, device: str) -> Tuple[np.ndarray, np.ndarray]:
        """Run ``self.net`` on a batch of images and return ``(lines, mask)``.

        ``batch`` is indexed as (n, h, w, c). With a ``TextDetBase`` network the
        batch is scaled to [0, 1], moved to channel-first layout and run in one
        call on ``device``. With a ``TextDetBaseDNN`` network each image is run
        separately (``device`` is unused) and the outputs are concatenated along
        axis 0.

        Raises:
            NotImplementedError: if ``self.net`` is neither of the two types.
        """
        if isinstance(self.net, TextDetBase):
            scaled = batch.astype(np.float32) / 255.
            net_input = torch.from_numpy(einops.rearrange(scaled, 'n h w c -> n c h w')).to(device)
            _, mask, lines = self.net(net_input)
            mask = mask.cpu().numpy()
            lines = lines.cpu().numpy()
            return lines, mask

        if not isinstance(self.net, TextDetBaseDNN):
            raise NotImplementedError

        masks, line_maps = [], []
        for single in batch:
            _, mask, lines = self.net(single)
            if mask.shape[1] == 2:     # some version of opencv spit out reversed result
                mask, lines = lines, mask
            masks.append(mask)
            line_maps.append(lines)
        lines = np.concatenate(line_maps, 0)
        mask = np.concatenate(masks, 0)
        return lines, mask

    @torch.no_grad()
    def __call__(self, img, refine_mode=REFINEMASK_INPAINT, keep_undetected_mask=False) -> Tuple[np.ndarray, np.ndarray, List[TextBlock]]:
        detect_size = self.detect_size if self.backend == 'torch' else (1024, 1024)
        img_in, ratio, dw, dh = preprocess_img(img, detect_size=detect_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')

        im_h, im_w = img.shape[:2]

        lines_map, mask = det_rearrange_forward(img, self.det_batch_forward_ctd, self.detect_size[0], self.det_rearrange_max_batches, self.device)
        blks = []
        resize_ratio = [1, 1]
        if lines_map is None:
            img_in, ratio, dw, dh = preprocess_img(img, detect_size=self.detect_size, device=self.device, half=self.half, to_tensor=self.backend=='torch')
            blks, mask, lines_map = self.net(img_in)

            if self.backend == 'opencv':
                if mask.shape[1] == 2:     # some version of opencv spit out reversed result
                    tmp = mask
                    mask = lines_map
                    lines_map = tmp
        
        resize_ratio = (im_w / (detect_size[0] - dw), im_h / (detect_size[1] - dh))
            mask = mask.squeeze()
            resize_ratio = (im_w / (self.detect_size[0] - dw), im_h / (self.detect_size[1] - dh))
            blks = postprocess_yolo(blks, self.conf_thresh, self.nms_thresh, resize_ratio)
            mask = mask[..., :mask.shape[0]-dh, :mask.shape[1]-dw]
            lines_map = lines_map[..., :lines_map.shape[2]-dh, :lines_map.shape[3]-dw]

        mask = postprocess_mask(mask)
        lines, scores = self.seg_rep(detect_size, lines_map)
        lines, scores = self.seg_rep(None, lines_map, height=im_h, width=im_w)
        box_thresh = 0.6
        idx = np.where(scores[0] > box_thresh)
        lines, scores = lines[0][idx], scores[0][idx]

        # map output to input img
        mask = mask[: mask.shape[0]-dh, : mask.shape[1]-dw]
        mask = cv2.resize(mask, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
        if lines.size == 0:
            lines = []
        else:
            lines = lines.astype(np.float64)
            lines[..., 0] *= resize_ratio[0]
            lines[..., 1] *= resize_ratio[1]
            lines = lines.astype(np.int32)
        blk_list = group_output(blks, lines, im_w, im_h, mask)
        mask_refined = refine_mask(img, mask, blk_list, refine_mode=refine_mode)
+3 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from typing import List
import cv2
import numpy as np
from ..textblock import TextBlock
from utils.imgproc_utils import draw_connected_labels, expand_textwindow, union_area
from utils.imgproc_utils import draw_connected_labels, expand_textwindow, union_area, enlarge_window

WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
@@ -159,7 +159,8 @@ def refine_undetected_mask(img: np.ndarray, mask_pred: np.ndarray, mask_refined:
def refine_mask(img: np.ndarray, pred_mask: np.ndarray, blk_list: List[TextBlock], refine_mode: int = REFINEMASK_INPAINT) -> np.ndarray:
    mask_refined = np.zeros_like(pred_mask)
    for blk in blk_list:
        bx1, by1, bx2, by2 = expand_textwindow(img.shape, blk.xyxy, expand_r=16)
        # bx1, by1, bx2, by2 = expand_textwindow(img.shape, blk.xyxy, expand_r=16)
        bx1, by1, bx2, by2 = enlarge_window(blk.xyxy, img.shape[1], img.shape[0])
        im = np.ascontiguousarray(img[by1: by2, bx1: bx2])
        msk = np.ascontiguousarray(pred_mask[by1: by2, bx1: bx2])
        mask_list = get_topk_masklist(im, msk)
+7 −3
Original line number Diff line number Diff line
@@ -37,7 +37,7 @@ class SegDetectorRepresenter():
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio

    def __call__(self, batch, pred, is_output_polygon=False):
    def __call__(self, batch, pred, is_output_polygon=False, height=None, width=None):
        '''
        batch: (image, polygons, ignore_tags
        batch: a dict produced by dataloaders.
@@ -57,9 +57,13 @@ class SegDetectorRepresenter():
        scores_batch = []
        # print(pred.size())
        batch_size = pred.size(0) if isinstance(pred, torch.Tensor) else pred.shape[0]

        if height is None:
            height = pred.shape[1]
        if width is None: 
            width = pred.shape[2]

        for batch_index in range(batch_size):
            # height, width = batch['shape'][batch_index]
            height, width = pred.shape[1], pred.shape[2]
            if is_output_polygon:
                boxes, scores = self.polygons_from_bitmap(pred[batch_index], segmentation[batch_index], width, height)
            else:
+6 −4
Original line number Diff line number Diff line
@@ -363,7 +363,7 @@ def examine_textblk(blk: TextBlock, im_w: int, im_h: int, sort: bool = False) ->
    if sort:
        blk.sort_lines()

def try_merge_textline(blk: TextBlock, blk2: TextBlock, fntsize_tol=1.3, distance_tol=2) -> bool:
def try_merge_textline(blk: TextBlock, blk2: TextBlock, fntsize_tol=1.4, distance_tol=2) -> bool:
    if blk2.merged:
        return False
    fntsize_div = blk.font_size / blk2.font_size
@@ -380,7 +380,9 @@ def try_merge_textline(blk: TextBlock, blk2: TextBlock, fntsize_tol=1.3, distanc
            return False
        if abs(cos_vec) < 0.866:   # cos30
            return False
        if distance > distance_tol * fntsz_avg or distance_p1 > fntsz_avg * 2.5:
        if distance > distance_tol * fntsz_avg:
            return False
        if blk.vertical and blk2.vertical and distance_p1 > fntsz_avg * 2.5:
            return False
    # merge
    blk.lines.append(blk2.lines[0])
@@ -512,12 +514,12 @@ def group_output(blks, lines, im_w, im_h, mask=None, sort_blklist=True) -> List[
        final_blk_list = sort_textblk_list(final_blk_list, im_w, im_h)

    for blk in final_blk_list:
        if blk.language == 'eng' and not blk.vertical:
        if blk.language != 'ja' and not blk.vertical:
            num_lines = len(blk.lines)
            if num_lines == 0:
                continue
            # blk.line_spacing = blk.bounding_rect()[3] / num_lines / blk.font_size
            expand_size = max(int(blk.font_size * 0.1), 2)
            expand_size = max(int(blk.font_size * 0.1), 3)
            rad = np.deg2rad(blk.angle)
            shifted_vec = np.array([[[-1, -1],[1, -1],[1, 1],[-1, 1]]])
            shifted_vec = shifted_vec * np.array([[[np.sin(rad), np.cos(rad)]]]) * expand_size
Loading