Commit ae095dbd authored by dmMaze's avatar dmMaze
Browse files

cleanup mit* ocr

parent a4b1b4d0
Loading
Loading
Loading
Loading
+1 −139
Original line number Diff line number Diff line
from typing import List
import numpy as np

from .base import OCRBase, register_OCR, DEVICE_SELECTOR, DEFAULT_DEVICE, TextBlock, OCR
 No newline at end of file

from .model_32px import OCR32pxModel
@register_OCR('mit32px')
class OCRMIT32px(OCRBase):
    """OCR module registered as ``'mit32px'``, backed by ``OCR32pxModel``.

    The wrapper keeps the user-selected ``device`` and ``chunk_size`` and
    forwards them to the model when it is created by ``_load_model`` or when
    a parameter changes through ``updateParam``.
    """

    # User-tunable parameters; the current choice lives under each 'value' key.
    params = {
        'chunk_size': {
            'type': 'selector',
            'options': [8, 16, 24, 32],
            'value': 16
        },
        'device': DEVICE_SELECTOR(),
        'description': 'OCRMIT32px'
    }
    device = DEFAULT_DEVICE
    chunk_size = 16

    # Download descriptor for the checkpoint; presumably consumed by the base
    # module's download logic, which is not visible in this file.
    download_file_list = [{
        'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr.zip',
        'files': ['ocr.ckpt'],
        'sha256_pre_calculated': ['d9f619a9dccce8ce88357d1b17d25f07806f225c033ea42c64e86c45446cfe71'],
        'save_files': ['data/models/mit32px_ocr.ckpt'],
        'archived_files': 'ocr.zip',
        'archive_sha256_pre_calculated': '47405638b96fa2540a5ee841a4cd792f25062c09d9458a973362d40785f95d7a',
    }]
    _load_model_keys = {'model'}

    def __init__(self, **params) -> None:
        """Read the initial device and chunk size from ``self.params``.

        The model itself is not created here; ``self.model`` stays ``None``
        until ``_load_model`` is called.
        """
        super().__init__(**params)
        self.device = self.params['device']['value']
        self.chunk_size = int(self.params['chunk_size']['value'])
        self.model: OCR32pxModel = None

    def _load_model(self):
        """Instantiate the model with the current device and chunk size."""
        self.model = OCR32pxModel(r'data/models/mit32px_ocr.ckpt', self.device, self.chunk_size)

    def ocr_img(self, img: np.ndarray) -> str:
        """Run OCR on a single image by delegating to the model's ``ocr_img``."""
        return self.model.ocr_img(img)

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Run the model over ``blk_list`` for ``img`` and return its result."""
        return self.model(img, blk_list)

    def updateParam(self, param_key: str, param_content):
        """Apply a parameter change and keep the wrapper and model in sync.

        The base class stores the new value; afterwards the device and chunk
        size are re-read from ``self.params`` and propagated.
        """
        super().updateParam(param_key, param_content)
        device = self.params['device']['value']
        chunk_size = int(self.params['chunk_size']['value'])

        # The model is None until _load_model has run, so it must not be
        # dereferenced unconditionally.
        model_loaded = self.model is not None

        if self.device != device:
            if model_loaded:
                self.model.to(device)
            # Remember the new device: _load_model reads self.device, and a
            # stale value would also re-trigger the move on every later call.
            self.device = device

        self.chunk_size = chunk_size
        if model_loaded:
            self.model.max_chunk_size = chunk_size


from .mit48px_ctc import OCR48pxCTC
@register_OCR('mit48px_ctc')
class OCRMIT48pxCTC(OCRBase):
    """OCR module registered as ``'mit48px_ctc'``, backed by ``OCR48pxCTC``.

    The wrapper keeps the user-selected ``device`` and ``chunk_size`` and
    forwards them to the model when it is created by ``_load_model`` or when
    a parameter changes through ``updateParam``.
    """

    # User-tunable parameters; the current choice lives under each 'value' key.
    params = {
        'chunk_size': {
            'type': 'selector',
            'options': [8, 16, 24, 32],
            'value': 16
        },
        'device': DEVICE_SELECTOR(),
        'description': 'mit48px_ctc'
    }
    device = DEFAULT_DEVICE
    chunk_size = 16

    # Download descriptor for the checkpoint and alphabet file; presumably
    # consumed by the base module's download logic, which is not visible here.
    download_file_list = [{
        'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/ocr-ctc.zip',
        'files': ['ocr-ctc.ckpt', 'alphabet-all-v5.txt'],
        'sha256_pre_calculated': ['8b0837a24da5fde96c23ca47bb7abd590cd5b185c307e348c6e0b7238178ed89', None],
        'save_files': ['data/models/mit48pxctc_ocr.ckpt', 'data/alphabet-all-v5.txt'],
        'archived_files': 'ocr-ctc.zip',
        'archive_sha256_pre_calculated': 'fc61c52f7a811bc72c54f6be85df814c6b60f63585175db27cb94a08e0c30101',
    }]
    _load_model_keys = {'model'}

    def __init__(self, **params) -> None:
        """Read the initial device and chunk size from ``self.params``.

        The model itself is not created here; ``self.model`` stays ``None``
        until ``_load_model`` is called.
        """
        super().__init__(**params)
        self.device = self.params['device']['value']
        self.chunk_size = int(self.params['chunk_size']['value'])
        self.model: OCR48pxCTC = None

    def _load_model(self):
        """Instantiate the model with the current device and chunk size."""
        self.model = OCR48pxCTC(r'data/models/mit48pxctc_ocr.ckpt', self.device, self.chunk_size)

    def ocr_img(self, img: np.ndarray) -> str:
        """Run OCR on a single image by delegating to the model's ``ocr_img``."""
        return self.model.ocr_img(img)

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Run the model over ``blk_list`` for ``img`` and return its result."""
        return self.model(img, blk_list)

    def updateParam(self, param_key: str, param_content):
        """Apply a parameter change and keep the wrapper and model in sync.

        The base class stores the new value; afterwards the device and chunk
        size are re-read from ``self.params`` and propagated.
        """
        super().updateParam(param_key, param_content)
        device = self.params['device']['value']
        chunk_size = int(self.params['chunk_size']['value'])

        # The model is None until _load_model has run, so it must not be
        # dereferenced unconditionally.
        model_loaded = self.model is not None

        if self.device != device:
            if model_loaded:
                self.model.to(device)
            # Remember the new device: _load_model reads self.device, and a
            # stale value would also re-trigger the move on every later call.
            self.device = device

        self.chunk_size = chunk_size
        if model_loaded:
            self.model.max_chunk_size = chunk_size


from .mit48px import Model48pxOCR
# Checkpoint path of the mit48px model: listed in OCRMIT48px.download_file_list
# and passed to Model48pxOCR when the model is loaded.
OCR48PXMODEL_PATH = r'data/models/ocr_ar_48px.ckpt'
@register_OCR('mit48px')
class OCRMIT48px(OCRBase):
    """OCR module registered as ``'mit48px'``, backed by ``Model48pxOCR``.

    The wrapper keeps the user-selected ``device`` and forwards it to the
    model when it is created by ``_load_model`` or when the parameter changes
    through ``updateParam``.
    """

    # User-tunable parameters; the current choice lives under the 'value' key.
    params = {
        'device': DEVICE_SELECTOR(),
        'description': 'mit48px'
    }
    device = DEFAULT_DEVICE

    # Download descriptor for the checkpoint and alphabet file; presumably
    # consumed by the base module's download logic, which is not visible here.
    download_file_list = [{
        'url': 'https://huggingface.co/zyddnys/manga-image-translator/resolve/main/',
        'files': [OCR48PXMODEL_PATH, 'data/alphabet-all-v7.txt'],
        'sha256_pre_calculated': ['29daa46d080818bb4ab239a518a88338cbccff8f901bef8c9db191a7cb97671d', None],
        'concatenate_url_filename': 2,
    }]
    _load_model_keys = {'model'}

    def __init__(self, **params) -> None:
        """Read the initial device from ``self.params``.

        The model itself is not created here; ``self.model`` stays ``None``
        until ``_load_model`` is called.
        """
        super().__init__(**params)
        self.device = self.params['device']['value']
        self.model: Model48pxOCR = None

    def _load_model(self):
        """Instantiate the model from ``OCR48PXMODEL_PATH`` on the current device."""
        self.model = Model48pxOCR(OCR48PXMODEL_PATH, self.device)

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]):
        """Run the model over ``blk_list`` for ``img`` and return its result."""
        return self.model(img, blk_list)

    def updateParam(self, param_key: str, param_content):
        """Apply a parameter change and keep the wrapper and model in sync.

        The base class stores the new value; afterwards the device is re-read
        from ``self.params`` and propagated.
        """
        super().updateParam(param_key, param_content)
        device = self.params['device']['value']
        if self.device == device:
            return

        # The model is None until _load_model has run, so only move it when
        # it actually exists.
        if self.model is not None:
            self.model.to(device)
        # Remember the new device: _load_model reads self.device, and a stale
        # value would also re-trigger the move on every later call.
        self.device = device
 No newline at end of file
+3 −9
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@ class OCRBase(BaseModule):

    _postprocess_hooks = OrderedDict()
    _preprocess_hooks = OrderedDict()
    _line_only: bool = False

    def __init__(self, **params) -> None:
        super().__init__(**params)
@@ -41,17 +42,10 @@ class OCRBase(BaseModule):
        self._ocr_blk_list(img, blk_list)
        for callback_name, callback in self._postprocess_hooks.items():
            callback(textblocks=blk_list, img=img, ocr_module=self)
        # for blk in blk_list:
        #     if isinstance(blk.text, List):
        #         for ii, t in enumerate(blk.text):
        #             for callback in self.postprocess_hooks:
        #                 blk.text[ii] = callback(t, blk=blk)
        #     else:
        #         for callback in self.postprocess_hooks:
        #             blk.text = callback(blk.text, blk=blk)

        return blk_list

    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock]) -> None:
    def _ocr_blk_list(self, img: np.ndarray, blk_list: List[TextBlock], *args, **kwargs) -> None:
        raise NotImplementedError

    def ocr_img(self, img: np.ndarray) -> str:
+17 −19
Original line number Diff line number Diff line
@@ -119,12 +119,11 @@ class Model48pxOCR:
        },
    }

    def __init__(self, model_path: str, device='cpu', max_chunk_size=16, *args, **kwargs):
    def __init__(self, model_path: str, device='cpu', *args, **kwargs):

        super().__init__(*args, **kwargs)

        self.device = device
        self.max_chunk_size = max_chunk_size

        with open('data/alphabet-all-v7.txt', 'r', encoding = 'utf-8') as fp:
            dictionary = [s[:-1] for s in fp.readlines()]
@@ -139,32 +138,31 @@ class Model48pxOCR:
        self.model.to(device)
        self.device = device
    
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock], verbose: bool = False, ignore_bubble: int = 0) -> List[TextBlock]:
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock], chunk_size = 16, regions: List = None, textblk_lst_indices: List = None) -> None:
        if isinstance(textblk_lst, TextBlock):
            textblk_lst = [textblk_lst]
        
        text_height = 48
        max_chunk_size = 16

        region_imgs = []
        if regions is None or textblk_lst_indices is None:
            regions = []
            textblk_lst_indices = []
        region_idx = 0
            for blk_idx, textblk in enumerate(textblk_lst):
                for ii in range(len(textblk)):
                    textblk_lst_indices.append(blk_idx)
                region_imgs.append(textblk.get_transformed_region(img, ii, 48, maxwidth=8100))
                region_idx += 1
                    region = textblk.get_transformed_region(img, ii, 48, maxwidth=8100)
                    regions.append(region)

        perm = range(len(region_imgs))
        perm = range(len(regions))
        chunck_idx = 0
        for indices in chunks(perm, max_chunk_size):
        for indices in chunks(perm, chunk_size):
            N = len(indices)
            widths = [region_imgs[i].shape[1] for i in indices]
            widths = [regions[i].shape[1] for i in indices]
            max_width = 4 * (max(widths) + 7) // 4
            region = np.zeros((N, text_height, max_width, 3), dtype = np.uint8)
            for i, idx in enumerate(indices):
                W = region_imgs[idx].shape[1]
                region[i, :, : W, :]=region_imgs[idx]
                W = regions[idx].shape[1]
                region[i, :, : W, :]=regions[idx]

            image_tensor = (torch.from_numpy(region).float() - 127.5) / 127.5
            image_tensor = einops.rearrange(image_tensor, 'N H W C -> N C H W')
+12 −12
Original line number Diff line number Diff line
@@ -392,11 +392,10 @@ class AvgMeter() :

class OCR48pxCTC:

    def __init__(self, model_path: str, device='cpu', max_chunk_size=16):
    def __init__(self, model_path: str, device='cpu'):
        with open('data/alphabet-all-v5.txt', 'r', encoding = 'utf-8') as fp :
            dictionary = [s[:-1] for s in fp.readlines()]
        self.device = device
        self.max_chunk_size = max_chunk_size

        model = OCR(dictionary, 768)
        sd = torch.load(model_path, map_location = 'cpu')
@@ -414,21 +413,22 @@ class OCR48pxCTC:
        self.device = device

    @torch.no_grad()
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock]) :
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock], chunk_size = 16, regions: List = None, textblk_lst_indices: List = None) -> None:
        if isinstance(textblk_lst, TextBlock):
            textblk_lst = [textblk_lst]
        
        if regions is None or textblk_lst_indices is None:
            regions = []
            textblk_lst_indices = []
        region_idx = 0
            for blk_idx, textblk in enumerate(textblk_lst):
                for ii in range(len(textblk)):
                    textblk_lst_indices.append(blk_idx)
                regions.append(textblk.get_transformed_region(img, ii, 48, maxwidth=8100))
                region_idx += 1
                    region = textblk.get_transformed_region(img, ii, 48, maxwidth=8100)
                    regions.append(region)

        perm = range(len(regions))
        chunck_idx = 0
        for indices in chunks(perm, self.max_chunk_size) :
        for indices in chunks(perm, chunk_size) :
            N = len(indices)
            widths = [regions[i].shape[1] for i in indices]
            # max_width = 4 * (max(widths) + 7) // 4
+13 −14
Original line number Diff line number Diff line
@@ -541,9 +541,8 @@ def chunks(lst, n):


class OCR32pxModel:
    def __init__(self, model_path, device='cpu', max_chunk_size=16) -> None:
    def __init__(self, model_path, device='cpu') -> None:
        self.device = device
        self.max_chunk_size = max_chunk_size
        self.text_height = 32

        self.net = None
@@ -562,22 +561,22 @@ class OCR32pxModel:
        self.device = device

    @torch.no_grad()
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock]) -> None:
    def __call__(self, img: np.ndarray, textblk_lst: List[TextBlock], chunk_size = 16, regions: List = None, textblk_lst_indices: List = None) -> None:
        if isinstance(textblk_lst, TextBlock):
            textblk_lst = [textblk_lst]

        if regions is None or textblk_lst_indices is None:
            regions = []
            textblk_lst_indices = []
        region_idx = 0
            for blk_idx, textblk in enumerate(textblk_lst):
                for ii in range(len(textblk)):
                    textblk_lst_indices.append(blk_idx)
                    region = textblk.get_transformed_region(img, ii, self.text_height, maxwidth=3064)
                    regions.append(region)
                region_idx += 1
        # regions = [textblk.get_transformed_region(img, idx, self.text_height) for idx in range(len(textblk))]

        perm = range(len(regions))
        chunck_idx = 0
        for indices in chunks(perm, self.max_chunk_size) :
        for indices in chunks(perm, chunk_size) :
            N = len(indices)
            widths = [regions[i].shape[1] for i in indices]
            max_width = 4 * (max(widths) + 7) // 4
Loading