Commit 097ea6c7 authored by dmMaze
Browse files

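Bugfixes: raise the OCR positional-encoding max length to 2048, cap text-region width by resizing (new `maxwidth` argument), and index every found page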

parent 669b5d41
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -76,7 +76,7 @@ class CustomTransformerEncoderLayer(nn.Module):
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.pe = PositionalEncoding(d_model, max_len = 768)
        self.pe = PositionalEncoding(d_model, max_len = 2048)

        self.activation = F.gelu

@@ -400,7 +400,10 @@ class OCR48pxCTC:

        model = OCR(dictionary, 768)
        sd = torch.load(model_path, map_location = 'cpu')
        model.load_state_dict(sd['model'] if 'model' in sd else sd)
        del sd['encoders.layers.0.pe.pe']
        del sd['encoders.layers.1.pe.pe']
        del sd['encoders.layers.2.pe.pe']
        model.load_state_dict(sd['model'] if 'model' in sd else sd, strict=False)
        model.eval()
        if self.device != 'cpu' :
            model = model.to(self.device)
@@ -421,9 +424,8 @@ class OCR48pxCTC:
        for blk_idx, textblk in enumerate(textblk_lst):
            for ii in range(len(textblk)):
                textblk_lst_indices.append(blk_idx)
                regions.append(textblk.get_transformed_region(img, ii, 48))
                regions.append(textblk.get_transformed_region(img, ii, 48, maxwidth=8100))
                region_idx += 1
        # regions = [textblk.get_transformed_region(img, idx, self.text_height) for idx in range(len(textblk))]
        perm = range(len(regions))
        chunck_idx = 0
        for indices in chunks(perm, self.max_chunk_size) :
+1 −4
Original line number Diff line number Diff line
@@ -571,10 +571,7 @@ class OCR32pxModel:
        for blk_idx, textblk in enumerate(textblk_lst):
            for ii in range(len(textblk)):
                textblk_lst_indices.append(blk_idx)
                region = textblk.get_transformed_region(img, ii, self.text_height)
                h, w = region.shape[:2]
                if w > 3064:    # positional embedding requires width <= 3072
                    region = region[:, :3064]
                region = textblk.get_transformed_region(img, ii, self.text_height, maxwidth=3064)
                regions.append(region)
                region_idx += 1
        # regions = [textblk.get_transformed_region(img, idx, self.text_height) for idx in range(len(textblk))]
+5 −3
Original line number Diff line number Diff line
@@ -185,7 +185,7 @@ class TextBlock(object):
        blk_dict = copy.deepcopy(vars(self))
        return blk_dict

    def get_transformed_region(self, img, idx, textheight) -> np.ndarray :
    def get_transformed_region(self, img, idx, textheight, maxwidth=None) -> np.ndarray :
        im_h, im_w = img.shape[:2]
        direction = 'v' if self.vertical else 'h'
        src_pts = np.array(self.lines[idx], dtype=np.float64)
@@ -208,8 +208,10 @@ class TextBlock(object):
            M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            region = cv2.warpPerspective(img, M, (w, h))
            region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
        # cv2.imshow('region'+str(idx), region)
        # cv2.waitKey(0)
        if maxwidth is not None:
            h, w = region.shape[: 2]
            if w > maxwidth:
                region = cv2.resize(region, (maxwidth, h))
        return region

    def get_text(self):
+0 −1
Original line number Diff line number Diff line
@@ -7,7 +7,6 @@ import json
import math
import platform
import warnings
from collections import OrderedDict, namedtuple
from copy import copy
from pathlib import Path

+5 −7
Original line number Diff line number Diff line
@@ -4,7 +4,6 @@ from docx import Document
import piexif.helper
import numpy as np
import os.path as osp
from collections import OrderedDict
from typing import Tuple, Union, List, Dict

from utils.logger import logger as LOGGER
@@ -93,20 +92,19 @@ class ProjImgTrans:
            self.pages = {}
            self._pagename2idx = {}
            self._idx2pagename = {}
            self.not_found_pages = {}
            page_dict = proj_dict['pages']
            not_found_pages = list(page_dict.keys())
            found_pages = find_all_imgs(img_dir=self.directory, abs_path=False)
            page_counter = 0
            for imname in found_pages:
            for ii, imname in enumerate(found_pages):
                if imname in page_dict:
                    self.pages[imname] = [TextBlock(**blk_dict) for blk_dict in page_dict[imname]]
                    self._pagename2idx[imname] = page_counter
                    self._idx2pagename[page_counter] = imname
                    page_counter += 1
                    not_found_pages.remove(imname)
                else:
                    self.pages[imname] = []
                    self.new_pages.append(imname)
                self._pagename2idx[imname] = ii
                self._idx2pagename[ii] = imname
            for imname in not_found_pages:
                self.not_found_pages[imname] = [TextBlock(**blk_dict) for blk_dict in page_dict[imname]]
        except Exception as e:
@@ -330,7 +328,7 @@ class ProjImgTrans:
            shutil.rmtree(cuts_dir)


    def load_doc(self, doc_path, delete_tmp_folder=True, fin_page_signal=None) -> OrderedDict:
    def load_doc(self, doc_path, delete_tmp_folder=True, fin_page_signal=None):

        tmp_bubble_folder = osp.join(self.directory, 'img_folder')
        os.makedirs(tmp_bubble_folder, exist_ok=True)
Loading